Spaces:
Sleeping
Sleeping
| import io | |
| import math | |
| import os | |
| import struct | |
| import threading | |
| import time | |
| from functools import lru_cache | |
| import torch | |
# Fixed-point scale used to turn softmax probabilities into integer symbol
# counts for the arithmetic coder (48 fractional bits of resolution).
PROB_SCALE = 1 << 48
# Bit width of the arithmetic coder's low/high interval registers.
ARITHMETIC_PRECISION = 64
class BitOutputStream:
    """Accumulate single bits and emit them to *file_obj* one byte at a
    time, most-significant bit first."""

    def __init__(self, file_obj):
        self.file_obj = file_obj
        self.byte = 0       # partial byte currently being assembled
        self.bit_count = 0  # number of bits held in self.byte

    def write_bit(self, bit):
        """Append one bit; flush to the underlying file on a full byte."""
        self.byte = (self.byte << 1) | bit
        self.bit_count += 1
        if self.bit_count < 8:
            return
        self.file_obj.write(bytes([self.byte]))
        self.byte = 0
        self.bit_count = 0

    def close(self):
        """Flush any partial byte, zero-padding its low-order bits."""
        if self.bit_count:
            self.byte <<= 8 - self.bit_count
            self.file_obj.write(bytes([self.byte]))
class BitInputStream:
    """Read a file one bit at a time, most-significant bit first.

    Returns -1 from read_bit once the underlying stream is exhausted.
    """

    def __init__(self, file_obj):
        self.file_obj = file_obj
        self.byte = 0       # byte currently being consumed
        self.bit_count = 0  # bits still unread in self.byte

    def read_bit(self):
        """Return the next bit (0 or 1), or -1 at end of stream."""
        if not self.bit_count:
            chunk = self.file_obj.read(1)
            if not chunk:
                return -1
            self.byte = chunk[0]
            self.bit_count = 8
        self.bit_count -= 1
        return (self.byte >> self.bit_count) & 1
class ArithmeticEncoder:
    """Encoder half of an integer arithmetic coder.

    Keeps a [low, high] interval in *precision*-bit fixed point and
    narrows it per symbol according to the symbol's cumulative-count
    range; determined leading bits are streamed out via *bit_output*.
    """

    def __init__(self, bit_output, precision=ARITHMETIC_PRECISION):
        self.bit_output = bit_output
        self.precision = precision
        self.max_val = (1 << precision) - 1
        self.quarter_val = 1 << (precision - 2)
        self.half_val = 1 << (precision - 1)
        self.three_quarter_val = 3 * self.quarter_val
        self.low = 0
        self.high = self.max_val
        self.pending_bits = 0  # deferred bits from midpoint straddles

    def encode_symbol(self, low_count, high_count, total_count):
        """Narrow the interval to the slice [low_count, high_count) of
        total_count, emitting any bits that become determined."""
        span = self.high - self.low + 1
        self.high = self.low + (span * high_count) // total_count - 1
        self.low = self.low + (span * low_count) // total_count
        while True:
            if self.high < self.half_val:
                # Entirely in the lower half: next output bit is 0.
                self._write_bit(0)
            elif self.low >= self.half_val:
                # Entirely in the upper half: next output bit is 1.
                self._write_bit(1)
                self.low -= self.half_val
                self.high -= self.half_val
            elif self.quarter_val <= self.low and self.high < self.three_quarter_val:
                # Straddles the midpoint (E3 case): defer one bit.
                self.pending_bits += 1
                self.low -= self.quarter_val
                self.high -= self.quarter_val
            else:
                break
            self.low <<= 1
            self.high = (self.high << 1) | 1

    def _write_bit(self, bit):
        """Emit *bit* followed by any deferred opposite bits."""
        self.bit_output.write_bit(bit)
        while self.pending_bits:
            self.bit_output.write_bit(1 - bit)
            self.pending_bits -= 1

    def finish(self):
        """Flush enough trailing bits to disambiguate the final interval."""
        self.pending_bits += 1
        self._write_bit(0 if self.low < self.quarter_val else 1)
class ArithmeticDecoder:
    """Decoder half of the integer arithmetic coder; mirrors
    ArithmeticEncoder's interval arithmetic exactly."""

    def __init__(self, bit_input, precision=ARITHMETIC_PRECISION):
        self.bit_input = bit_input
        self.precision = precision
        self.max_val = (1 << precision) - 1
        self.quarter_val = 1 << (precision - 2)
        self.half_val = 1 << (precision - 1)
        self.three_quarter_val = 3 * self.quarter_val
        self.low = 0
        self.high = self.max_val
        # Prime the code value with the first `precision` bits; anything
        # past end-of-stream is treated as 0.
        self.value = 0
        for _ in range(precision):
            self.value = (self.value << 1) | self._next_bit()

    def _next_bit(self):
        """Read one bit, substituting 0 once the stream is exhausted."""
        bit = self.bit_input.read_bit()
        return 0 if bit == -1 else bit

    def decode_symbol_find_count(self, total_count):
        """Project the current code value into [0, total_count); the
        caller maps this cumulative count back to a symbol."""
        span = self.high - self.low + 1
        return ((self.value - self.low + 1) * total_count - 1) // span

    def update_range(self, low_count, high_count, total_count):
        """Narrow the interval exactly as the encoder did for this
        symbol, shifting fresh input bits into the code value."""
        span = self.high - self.low + 1
        self.high = self.low + (span * high_count) // total_count - 1
        self.low = self.low + (span * low_count) // total_count
        while True:
            if self.high < self.half_val:
                pass  # lower half: no offset to subtract
            elif self.low >= self.half_val:
                self.value -= self.half_val
                self.low -= self.half_val
                self.high -= self.half_val
            elif self.quarter_val <= self.low and self.high < self.three_quarter_val:
                self.value -= self.quarter_val
                self.low -= self.quarter_val
                self.high -= self.quarter_val
            else:
                break
            self.low <<= 1
            self.high = (self.high << 1) | 1
            self.value = (self.value << 1) | self._next_bit()
| def _strip_pth(model_path): | |
| return model_path[:-4] if model_path.endswith(".pth") else model_path | |
| def _prepare_logits(logits): | |
| if not isinstance(logits, torch.Tensor): | |
| logits = torch.tensor(logits) | |
| if logits.ndim > 1: | |
| logits = logits[-1] | |
| return logits.float() | |
def tokenize_text(tokenizer, text):
    """Encode *text* and normalize the result to a plain list of int
    token ids (handles encoders that return objects exposing .ids)."""
    encoded = tokenizer.encode(text)
    ids = encoded.ids if hasattr(encoded, "ids") else encoded
    return list(map(int, ids))
def decode_tokens(tokenizer, tokens):
    """Render a sequence of token ids back into text via *tokenizer*."""
    return tokenizer.decode(tokens)
# Serializes load_rwkv_model's import + construction of the rwkv model and
# tokenizer — presumably because that setup is not safe to run concurrently;
# NOTE(review): confirm against the rwkv package's threading guarantees.
_MODEL_LOCK = threading.Lock()
def load_rwkv_model(model_path, tokenizer_name, strategy):
    """Load an RWKV model plus its TRIE tokenizer.

    Args:
        model_path: path to the model checkpoint; a ".pth" suffix is
            stripped before handing it to RWKV.
        tokenizer_name: tokenizer vocabulary name or path.
        strategy: RWKV strategy string; silently downgraded to
            "cpu fp32" when it requests CUDA but CUDA is unavailable.

    Returns:
        (model, tokenizer) tuple.

    Raises:
        ValueError: if model_path or tokenizer_name is empty.
    """
    if not model_path:
        raise ValueError("RWKV model path is required.")
    if not tokenizer_name:
        raise ValueError("RWKV tokenizer name or path is required.")
    if "cuda" in strategy and not torch.cuda.is_available():
        strategy = "cpu fp32"
    # The rwkv package reads these environment variables at import time.
    os.environ["RWKV_JIT_ON"] = "1"
    os.environ["RWKV_V7_ON"] = "1"
    os.environ["RWKV_CUDA_ON"] = "1" if "cuda" in strategy else "0"
    with _MODEL_LOCK:
        from rwkv.model import RWKV
        from rwkv.rwkv_tokenizer import TRIE_TOKENIZER
        loaded_model = RWKV(model=_strip_pth(model_path), strategy=strategy)
        trie_tokenizer = TRIE_TOKENIZER(tokenizer_name)
    return loaded_model, trie_tokenizer
def compress_tokens(
    tokens,
    model,
    context_window=2048,
    original_bytes=None,
    progress=None,
    progress_desc="Compressing",
):
    """Arithmetic-code *tokens* using *model*'s next-token distributions.

    The output begins with a 4-byte big-endian token count followed by the
    arithmetic-coded bitstream. decompress_bytes must be run with the same
    model and context_window to invert this.

    Args:
        tokens: iterable of token ids (coerced to int).
        model: RWKV-style model; model.forward([token], state) must return
            (logits, new_state).
        context_window: tokens fed before the context/state is reset; the
            decoder mirrors these resets, so it must match. Must be > 0.
        original_bytes: size of the uncompressed input, used only for the
            ratio statistics (0/None disables ratio computation).
        progress: optional callback ((done, total), desc=..., unit=...).
        progress_desc: label passed to the progress callback.

    Returns:
        (data, stats) where data is the compressed bytes and stats holds
        token count, sizes, ratios, duration, and throughput.

    Raises:
        ValueError: on non-positive context_window or empty tokens.
    """
    if context_window <= 0:
        raise ValueError("context_window must be positive.")
    token_ids = [int(token_id) for token_id in tokens]
    if not token_ids:
        raise ValueError("No tokens to compress.")
    output = io.BytesIO()
    # Header: token count, so the decoder knows when to stop.
    output.write(struct.pack(">I", len(token_ids)))
    bit_output = BitOutputStream(output)
    encoder = ArithmeticEncoder(bit_output, precision=ARITHMETIC_PRECISION)
    context_tokens = []
    state = None
    total_nll = 0.0
    start_time = time.time()
    total_tokens = len(token_ids)
    if progress is not None:
        progress((0, total_tokens), desc=progress_desc, unit="token")
    # Throttle progress callbacks to ~100 updates, consistent with
    # decompress_bytes (previously this fired on every single token).
    progress_step = max(1, total_tokens // 100)
    with torch.inference_mode():
        for idx, token_id in enumerate(token_ids):
            if len(context_tokens) >= context_window:
                # Context full: restart from a fresh state. The decoder
                # performs the identical reset, keeping both sides in sync.
                context_tokens = []
                state = None
            input_token = context_tokens[-1] if context_tokens else 0
            logits, state = model.forward([input_token], state)
            next_logits = _prepare_logits(logits)
            probs = torch.softmax(next_logits, dim=-1)
            # Quantize probabilities to integer counts; clamp to >= 1 so
            # every symbol stays representable by the coder.
            counts = (probs * PROB_SCALE).to(torch.long)
            counts = torch.clamp(counts, min=1)
            cdf = torch.cumsum(counts, dim=-1)
            total_count = int(cdf[-1].item())
            # Accumulate the model's negative log-likelihood for the
            # theoretical (entropy-limit) compression stats.
            prob_val = probs[token_id]
            total_nll += float((-torch.log(prob_val)).item())
            low_val = int(cdf[token_id - 1].item()) if token_id > 0 else 0
            high_val = int(cdf[token_id].item())
            encoder.encode_symbol(low_val, high_val, total_count)
            context_tokens.append(token_id)
            if progress is not None and (idx + 1 == total_tokens or (idx + 1) % progress_step == 0):
                progress((idx + 1, total_tokens), desc=progress_desc, unit="token")
    encoder.finish()
    bit_output.close()
    data = output.getvalue()
    end_time = time.time()
    original_bytes = int(original_bytes or 0)
    compressed_bytes = len(data)
    ratio = compressed_bytes / original_bytes if original_bytes > 0 else 0.0
    # Shannon-limit size implied by the model's NLL (nats -> bits -> bytes).
    theoretical_bits = total_nll / math.log(2)
    theoretical_bytes = theoretical_bits / 8
    theoretical_ratio = theoretical_bytes / original_bytes if original_bytes > 0 else 0.0
    duration = end_time - start_time
    speed = len(token_ids) / duration if duration > 0 else 0.0
    stats = {
        "tokens": len(token_ids),
        "original_bytes": original_bytes,
        "compressed_bytes": compressed_bytes,
        "ratio": ratio,
        "theoretical_ratio": theoretical_ratio,
        "duration_s": duration,
        "speed_toks_per_s": speed,
    }
    return data, stats
def compress_text(text, model, tokenizer, context_window=2048):
    """Tokenize *text* and arithmetic-code it; returns (data, stats)."""
    encoded = tokenize_text(tokenizer, text)
    utf8_len = len(text.encode("utf-8"))
    return compress_tokens(
        encoded,
        model,
        context_window=context_window,
        original_bytes=utf8_len,
    )
def decompress_bytes(
    data,
    model,
    tokenizer,
    context_window=2048,
    progress=None,
    progress_desc="Decompressing",
):
    """Invert compress_tokens: arithmetic-decode *data* back into text.

    Must be run with the same model and context_window used at
    compression time so the model produces identical distributions.

    Args:
        data: bytes from compress_tokens — a 4-byte big-endian token
            count followed by the arithmetic-coded bitstream.
        model: RWKV-style model; model.forward([token], state) must
            return (logits, new_state).
        tokenizer: object with a .decode(tokens) method.
        context_window: context reset interval; must match compression.
        progress: optional callback ((done, total), desc=..., unit=...).
        progress_desc: label passed to the progress callback.

    Returns:
        (text, stats) where stats holds the token count and duration.

    Raises:
        ValueError: on non-positive context_window or truncated data.
    """
    if context_window <= 0:
        raise ValueError("context_window must be positive.")
    if not data or len(data) < 4:
        raise ValueError("Compressed data is empty or invalid.")
    buffer = io.BytesIO(data)
    total_tokens_bytes = buffer.read(4)
    total_tokens = struct.unpack(">I", total_tokens_bytes)[0]
    bit_input = BitInputStream(buffer)
    decoder = ArithmeticDecoder(bit_input, precision=ARITHMETIC_PRECISION)
    decoded_tokens = []
    context_tokens = []
    state = None
    start_time = time.time()
    if progress is not None:
        progress((0, total_tokens), desc=progress_desc, unit="token")
    # Report at most ~100 progress updates regardless of token count.
    progress_step = max(1, total_tokens // 100)
    with torch.inference_mode():
        for idx in range(total_tokens):
            if len(context_tokens) >= context_window:
                # Mirror the encoder's context reset so the model sees
                # the exact same inputs and emits the same distributions.
                context_tokens = []
                state = None
            input_token = context_tokens[-1] if context_tokens else 0
            logits, state = model.forward([input_token], state)
            next_logits = _prepare_logits(logits)
            probs = torch.softmax(next_logits, dim=-1)
            # Identical quantization to the encoder: scale, clamp >= 1,
            # cumulative sum — any mismatch would desynchronize decoding.
            counts = (probs * PROB_SCALE).to(torch.long)
            counts = torch.clamp(counts, min=1)
            cdf = torch.cumsum(counts, dim=-1)
            total_count = int(cdf[-1].item())
            count_val = decoder.decode_symbol_find_count(total_count)
            count_val_tensor = torch.tensor(count_val, device=cdf.device)
            # First index whose cumulative count exceeds count_val is the
            # decoded symbol.
            target_token_id = int(torch.searchsorted(cdf, count_val_tensor, right=True).item())
            decoded_tokens.append(target_token_id)
            context_tokens.append(target_token_id)
            low_val = int(cdf[target_token_id - 1].item()) if target_token_id > 0 else 0
            high_val = int(cdf[target_token_id].item())
            decoder.update_range(low_val, high_val, total_count)
            if progress is not None and (idx + 1 == total_tokens or (idx + 1) % progress_step == 0):
                progress((idx + 1, total_tokens), desc=progress_desc, unit="token")
    text = decode_tokens(tokenizer, decoded_tokens)
    duration = time.time() - start_time
    stats = {
        "tokens": total_tokens,
        "duration_s": duration,
    }
    return text, stats