"""Neural text compression: an RWKV language model supplies next-token
probabilities that drive an integer arithmetic coder, producing a
near-entropy compressed bitstream that the same model can invert exactly.
"""

import io
import math
import os
import struct
import threading
import time
from functools import lru_cache

import torch

# Probabilities are quantized to integer counts on this scale so the
# arithmetic coder works with exact (unbounded) Python integers.
PROB_SCALE = 1 << 48
# Bit width of the arithmetic coder's interval registers.
ARITHMETIC_PRECISION = 64


class BitOutputStream:
    """Packs bits MSB-first into bytes and writes them to ``file_obj``."""

    def __init__(self, file_obj):
        self.file_obj = file_obj
        self.byte = 0        # partial byte being assembled
        self.bit_count = 0   # number of bits currently held in ``byte``

    def write_bit(self, bit):
        """Buffer one bit (0 or 1); emit a byte once 8 bits accumulate."""
        self.byte = (self.byte << 1) | bit
        self.bit_count += 1
        if self.bit_count == 8:
            self.file_obj.write(bytes([self.byte]))
            self.byte = 0
            self.bit_count = 0

    def close(self):
        """Flush a trailing partial byte, zero-padding its low bits."""
        if self.bit_count > 0:
            self.byte <<= 8 - self.bit_count
            self.file_obj.write(bytes([self.byte]))


class BitInputStream:
    """Reads bits MSB-first from ``file_obj``; returns -1 once exhausted."""

    def __init__(self, file_obj):
        self.file_obj = file_obj
        self.byte = 0        # byte currently being consumed
        self.bit_count = 0   # bits remaining in ``byte``

    def read_bit(self):
        """Return the next bit, or -1 at end of stream."""
        if self.bit_count == 0:
            bytes_data = self.file_obj.read(1)
            if not bytes_data:
                return -1
            self.byte = bytes_data[0]
            self.bit_count = 8
        bit = (self.byte >> (self.bit_count - 1)) & 1
        self.bit_count -= 1
        return bit


class ArithmeticEncoder:
    """Integer arithmetic coder (encoder half) using pending-bit counting
    for carry/underflow handling (the classic Witten-Neal-Cleary scheme)."""

    def __init__(self, bit_output, precision=ARITHMETIC_PRECISION):
        self.bit_output = bit_output
        self.precision = precision
        self.max_val = (1 << precision) - 1
        self.quarter_val = 1 << (precision - 2)
        self.half_val = 1 << (precision - 1)
        self.three_quarter_val = self.quarter_val * 3
        self.low = 0
        self.high = self.max_val
        self.pending_bits = 0  # underflow bits awaiting the next settled bit

    def encode_symbol(self, low_count, high_count, total_count):
        """Narrow the interval to the symbol occupying cumulative range
        [low_count, high_count) of ``total_count``, then renormalize,
        emitting every bit that has become determined."""
        range_val = self.high - self.low + 1
        self.high = self.low + (range_val * high_count) // total_count - 1
        self.low = self.low + (range_val * low_count) // total_count
        while True:
            if self.high < self.half_val:
                # Interval entirely in the lower half: next bit is 0.
                self._write_bit(0)
            elif self.low >= self.half_val:
                # Interval entirely in the upper half: next bit is 1.
                self._write_bit(1)
                self.low -= self.half_val
                self.high -= self.half_val
            elif self.low >= self.quarter_val and self.high < self.three_quarter_val:
                # Underflow: interval straddles the midpoint inside the
                # middle half — defer the bit until its value is known.
                self.pending_bits += 1
                self.low -= self.quarter_val
                self.high -= self.quarter_val
            else:
                break
            self.low <<= 1
            self.high = (self.high << 1) | 1

    def _write_bit(self, bit):
        # Each settled bit is followed by the deferred underflow bits,
        # which always take the opposite value.
        self.bit_output.write_bit(bit)
        while self.pending_bits > 0:
            self.bit_output.write_bit(1 - bit)
            self.pending_bits -= 1

    def finish(self):
        """Emit final disambiguating bits so the decoder can terminate."""
        self.pending_bits += 1
        if self.low < self.quarter_val:
            self._write_bit(0)
        else:
            self._write_bit(1)


class ArithmeticDecoder:
    """Decoder counterpart of :class:`ArithmeticEncoder`; mirrors its
    interval arithmetic exactly."""

    def __init__(self, bit_input, precision=ARITHMETIC_PRECISION):
        self.bit_input = bit_input
        self.precision = precision
        self.max_val = (1 << precision) - 1
        self.quarter_val = 1 << (precision - 2)
        self.half_val = 1 << (precision - 1)
        self.three_quarter_val = self.quarter_val * 3
        self.low = 0
        self.high = self.max_val
        self.value = 0
        # Prime the code register; short streams are zero-padded, which
        # matches the encoder's byte-boundary zero padding.
        for _ in range(precision):
            read_val = self.bit_input.read_bit()
            if read_val == -1:
                read_val = 0
            self.value = (self.value << 1) | read_val

    def decode_symbol_find_count(self, total_count):
        """Project the current code value into [0, total_count); the caller
        then looks up which symbol's cumulative range contains it."""
        range_val = self.high - self.low + 1
        count = ((self.value - self.low + 1) * total_count - 1) // range_val
        return count

    def update_range(self, low_count, high_count, total_count):
        """Narrow the interval to the decoded symbol and renormalize,
        consuming bits as the encoder emitted them."""
        range_val = self.high - self.low + 1
        self.high = self.low + (range_val * high_count) // total_count - 1
        self.low = self.low + (range_val * low_count) // total_count
        while True:
            if self.high < self.half_val:
                pass  # lower half: no register offset needed
            elif self.low >= self.half_val:
                self.value -= self.half_val
                self.low -= self.half_val
                self.high -= self.half_val
            elif self.low >= self.quarter_val and self.high < self.three_quarter_val:
                self.value -= self.quarter_val
                self.low -= self.quarter_val
                self.high -= self.quarter_val
            else:
                break
            self.low <<= 1
            self.high = (self.high << 1) | 1
            bit = self.bit_input.read_bit()
            if bit == -1:
                bit = 0
            self.value = (self.value << 1) | bit


def _strip_pth(model_path):
    """Drop a trailing ``.pth`` extension (the RWKV loader re-appends it)."""
    return model_path[:-4] if model_path.endswith(".pth") else model_path


def _prepare_logits(logits):
    """Coerce model output into a 1-D float tensor of next-token logits."""
    if not isinstance(logits, torch.Tensor):
        logits = torch.tensor(logits)
    if logits.ndim > 1:
        # Keep only the logits for the last position.
        logits = logits[-1]
    return logits.float()


def tokenize_text(tokenizer, text):
    """Encode ``text`` and normalize the result to a plain list of ints."""
    tokenized = tokenizer.encode(text)
    if hasattr(tokenized, "ids"):
        # HuggingFace-style tokenizers return an Encoding object.
        tokenized = tokenized.ids
    return [int(token_id) for token_id in tokenized]


def decode_tokens(tokenizer, tokens):
    """Decode a token-id sequence back into text."""
    return tokenizer.decode(tokens)


# Serializes RWKV construction: the rwkv package reads process-wide env
# vars at import/construction time.
_MODEL_LOCK = threading.Lock()


@lru_cache(maxsize=2)
def load_rwkv_model(model_path, tokenizer_name, strategy):
    """Load and memoize an RWKV model together with its TRIE tokenizer.

    Args:
        model_path: path to the ``.pth`` checkpoint (extension optional).
        tokenizer_name: path/name of the tokenizer vocabulary file.
        strategy: RWKV strategy string, e.g. ``"cuda fp16"`` or ``"cpu fp32"``.

    Returns:
        ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: if ``model_path`` or ``tokenizer_name`` is empty.
    """
    if not model_path:
        raise ValueError("RWKV model path is required.")
    if not tokenizer_name:
        raise ValueError("RWKV tokenizer name or path is required.")
    if "cuda" in strategy and not torch.cuda.is_available():
        # NOTE(review): this fallback rewrites ``strategy`` after the
        # lru_cache key was formed, so a "cuda ..." and a "cpu fp32" call
        # can cache two copies of the same CPU model — confirm intended.
        strategy = "cpu fp32"
    os.environ["RWKV_JIT_ON"] = "1"
    os.environ["RWKV_V7_ON"] = "1"
    os.environ["RWKV_CUDA_ON"] = "1" if "cuda" in strategy else "0"
    with _MODEL_LOCK:
        # Imported lazily so the env vars above take effect first.
        from rwkv.model import RWKV
        from rwkv.rwkv_tokenizer import TRIE_TOKENIZER

        model = RWKV(model=_strip_pth(model_path), strategy=strategy)
        tokenizer = TRIE_TOKENIZER(tokenizer_name)
    return model, tokenizer


def compress_tokens(
    tokens,
    model,
    context_window=2048,
    original_bytes=None,
    progress=None,
    progress_desc="Compressing",
):
    """Arithmetic-code ``tokens`` using ``model``'s next-token distribution.

    Args:
        tokens: iterable of token ids to compress.
        model: RWKV-style model; ``model.forward([token], state)`` must
            return ``(logits, new_state)``.
        context_window: token count after which the model state is reset
            (bounds memory; mirrored exactly during decompression).
        original_bytes: size of the uncompressed source, used only for the
            ratio statistics.
        progress: optional callback ``progress((done, total), desc=..., unit=...)``.
        progress_desc: label forwarded to ``progress``.

    Returns:
        ``(data, stats)``: compressed bytes (4-byte big-endian token-count
        header followed by the coded bitstream) and a statistics dict.

    Raises:
        ValueError: if ``context_window`` is not positive or ``tokens`` is
        empty.
    """
    if context_window <= 0:
        raise ValueError("context_window must be positive.")
    token_ids = [int(token_id) for token_id in tokens]
    if not token_ids:
        raise ValueError("No tokens to compress.")
    output = io.BytesIO()
    output.write(struct.pack(">I", len(token_ids)))
    bit_output = BitOutputStream(output)
    encoder = ArithmeticEncoder(bit_output, precision=ARITHMETIC_PRECISION)
    context_tokens = []
    state = None
    total_nll = 0.0
    start_time = time.time()
    total_tokens = len(token_ids)
    if progress is not None:
        progress((0, total_tokens), desc=progress_desc, unit="token")
    with torch.inference_mode():
        for idx, token_id in enumerate(token_ids):
            if len(context_tokens) >= context_window:
                # Hard context reset; the decoder performs the identical
                # reset at the same position, keeping both sides in sync.
                context_tokens = []
                state = None
            input_token = context_tokens[-1] if context_tokens else 0
            logits, state = model.forward([input_token], state)
            next_logits = _prepare_logits(logits)
            probs = torch.softmax(next_logits, dim=-1)
            # Quantize probabilities to integer counts; clamp to >= 1 so
            # every symbol stays encodable (a zero-width interval would
            # make some token impossible to represent).
            counts = (probs * PROB_SCALE).to(torch.long)
            counts = torch.clamp(counts, min=1)
            cdf = torch.cumsum(counts, dim=-1)
            total_count = int(cdf[-1].item())
            prob_val = probs[token_id]
            total_nll += float((-torch.log(prob_val)).item())
            low_val = int(cdf[token_id - 1].item()) if token_id > 0 else 0
            high_val = int(cdf[token_id].item())
            encoder.encode_symbol(low_val, high_val, total_count)
            context_tokens.append(token_id)
            if progress is not None:
                progress((idx + 1, total_tokens), desc=progress_desc, unit="token")
    encoder.finish()
    bit_output.close()
    data = output.getvalue()
    end_time = time.time()
    original_bytes = int(original_bytes or 0)
    compressed_bytes = len(data)
    ratio = compressed_bytes / original_bytes if original_bytes > 0 else 0.0
    # Shannon bound implied by the model's own probability estimates.
    theoretical_bits = total_nll / math.log(2)
    theoretical_bytes = theoretical_bits / 8
    theoretical_ratio = theoretical_bytes / original_bytes if original_bytes > 0 else 0.0
    duration = end_time - start_time
    speed = len(token_ids) / duration if duration > 0 else 0.0
    stats = {
        "tokens": len(token_ids),
        "original_bytes": original_bytes,
        "compressed_bytes": compressed_bytes,
        "ratio": ratio,
        "theoretical_ratio": theoretical_ratio,
        "duration_s": duration,
        "speed_toks_per_s": speed,
    }
    return data, stats


def compress_text(text, model, tokenizer, context_window=2048):
    """Tokenize ``text`` and compress it; ratio stats use its UTF-8 size."""
    tokens = tokenize_text(tokenizer, text)
    original_bytes = len(text.encode("utf-8"))
    return compress_tokens(tokens, model, context_window=context_window, original_bytes=original_bytes)


def decompress_bytes(
    data,
    model,
    tokenizer,
    context_window=2048,
    progress=None,
    progress_desc="Decompressing",
):
    """Invert :func:`compress_tokens` and decode the result to text.

    ``model`` and ``context_window`` must match the values used during
    compression, otherwise decoding diverges.

    Args:
        data: bytes produced by :func:`compress_tokens`.
        model: same RWKV-style model used for compression.
        tokenizer: tokenizer whose ``decode`` maps token ids back to text.
        context_window: must equal the compression-time window.
        progress: optional callback ``progress((done, total), desc=..., unit=...)``.
        progress_desc: label forwarded to ``progress``.

    Returns:
        ``(text, stats)`` tuple.

    Raises:
        ValueError: on non-positive ``context_window`` or truncated data.
    """
    if context_window <= 0:
        raise ValueError("context_window must be positive.")
    if not data or len(data) < 4:
        raise ValueError("Compressed data is empty or invalid.")
    buffer = io.BytesIO(data)
    total_tokens_bytes = buffer.read(4)
    total_tokens = struct.unpack(">I", total_tokens_bytes)[0]
    bit_input = BitInputStream(buffer)
    decoder = ArithmeticDecoder(bit_input, precision=ARITHMETIC_PRECISION)
    decoded_tokens = []
    context_tokens = []
    state = None
    start_time = time.time()
    if progress is not None:
        progress((0, total_tokens), desc=progress_desc, unit="token")
    # Throttle progress callbacks to ~100 updates over the whole run.
    progress_step = max(1, total_tokens // 100)
    with torch.inference_mode():
        for idx in range(total_tokens):
            if len(context_tokens) >= context_window:
                # Mirror the encoder's periodic context reset exactly.
                context_tokens = []
                state = None
            input_token = context_tokens[-1] if context_tokens else 0
            logits, state = model.forward([input_token], state)
            next_logits = _prepare_logits(logits)
            probs = torch.softmax(next_logits, dim=-1)
            counts = (probs * PROB_SCALE).to(torch.long)
            counts = torch.clamp(counts, min=1)
            cdf = torch.cumsum(counts, dim=-1)
            total_count = int(cdf[-1].item())
            count_val = decoder.decode_symbol_find_count(total_count)
            count_val_tensor = torch.tensor(count_val, device=cdf.device)
            # First index whose cumulative count exceeds count_val owns it.
            target_token_id = int(torch.searchsorted(cdf, count_val_tensor, right=True).item())
            decoded_tokens.append(target_token_id)
            context_tokens.append(target_token_id)
            low_val = int(cdf[target_token_id - 1].item()) if target_token_id > 0 else 0
            high_val = int(cdf[target_token_id].item())
            decoder.update_range(low_val, high_val, total_count)
            if progress is not None and (idx + 1 == total_tokens or (idx + 1) % progress_step == 0):
                progress((idx + 1, total_tokens), desc=progress_desc, unit="token")
    text = decode_tokens(tokenizer, decoded_tokens)
    duration = time.time() - start_time
    stats = {
        "tokens": total_tokens,
        "duration_s": duration,
    }
    return text, stats