LLM-Compressor / llm_compressor.py
Jellyfish042's picture
Fix launch args and add decompression progress
eb04823
import io
import math
import os
import struct
import threading
import time
from functools import lru_cache
import torch
# Fixed-point scale (2**48) that softmax probabilities are multiplied by
# before being truncated to the integer counts fed to the arithmetic coder.
PROB_SCALE = 1 << 48
# Default bit width of the arithmetic coder's low/high/value registers.
ARITHMETIC_PRECISION = 64
class BitOutputStream:
    """Pack single bits, most significant bit first, into a binary file object.

    Bits are accumulated in ``self.byte``; every eighth bit the finished byte
    is written to ``file_obj``.
    """

    def __init__(self, file_obj):
        """Wrap *file_obj*, which must provide a ``write(bytes)`` method."""
        self.file_obj = file_obj
        self.byte = 0  # bits collected so far for the byte being built
        self.bit_count = 0  # number of valid bits currently held in self.byte

    def write_bit(self, bit):
        """Append one bit (0 or 1); emit a byte once eight bits are buffered."""
        self.byte = (self.byte << 1) | bit
        self.bit_count += 1
        if self.bit_count == 8:
            self.file_obj.write(bytes([self.byte]))
            self.byte = 0
            self.bit_count = 0

    def close(self):
        """Flush a partially filled byte, zero-padding the low-order bits.

        The buffer is cleared after the flush so that calling ``close()``
        again does not write the padded byte a second time. The underlying
        file object is not closed.
        """
        if self.bit_count > 0:
            self.byte <<= 8 - self.bit_count
            self.file_obj.write(bytes([self.byte]))
            # Reset so a repeated close() (or later writes) start from a
            # clean buffer instead of re-emitting / corrupting this byte.
            self.byte = 0
            self.bit_count = 0
class BitInputStream:
    """Hand out the bits of a binary file object, most significant bit first."""

    def __init__(self, file_obj):
        """Wrap *file_obj*, which must provide a ``read(n)`` method."""
        self.file_obj = file_obj
        self.byte = 0  # byte currently being consumed
        self.bit_count = 0  # bits of self.byte not yet returned

    def read_bit(self):
        """Return the next bit (0 or 1), or -1 once the source is exhausted."""
        if self.bit_count == 0:
            # Buffer drained: pull in the next byte, if there is one.
            chunk = self.file_obj.read(1)
            if not chunk:
                return -1
            self.byte, self.bit_count = chunk[0], 8
        self.bit_count -= 1
        return (self.byte >> self.bit_count) & 1
class ArithmeticEncoder:
    """Integer arithmetic encoder that emits its code bits to a bit sink.

    The coding interval ``[low, high]`` lives in ``precision``-bit registers.
    Whenever the interval falls entirely into one half, or straddles the
    midpoint inside the middle two quarters, it is rescaled and the
    corresponding bit is written (or deferred via ``pending_bits``).
    """

    def __init__(self, bit_output, precision=ARITHMETIC_PRECISION):
        """Start with the full interval; *bit_output* must offer ``write_bit``."""
        self.bit_output = bit_output
        self.precision = precision
        self.max_val = (1 << precision) - 1
        self.quarter_val = 1 << (precision - 2)
        self.half_val = 1 << (precision - 1)
        self.three_quarter_val = self.quarter_val * 3
        self.low = 0
        self.high = self.max_val
        self.pending_bits = 0  # deferred bits from middle-straddle rescaling

    def encode_symbol(self, low_count, high_count, total_count):
        """Narrow the interval to the symbol's cumulative-count slice.

        The symbol occupies ``[low_count, high_count)`` out of *total_count*.
        """
        span = self.high - self.low + 1
        base = self.low
        self.high = base + (span * high_count) // total_count - 1
        self.low = base + (span * low_count) // total_count

        while True:
            if self.high < self.half_val:
                # Interval entirely in the lower half: the next bit is 0.
                offset = 0
                self._write_bit(0)
            elif self.low >= self.half_val:
                # Interval entirely in the upper half: the next bit is 1.
                offset = self.half_val
                self._write_bit(1)
            elif self.low >= self.quarter_val and self.high < self.three_quarter_val:
                # Straddles the midpoint: defer the bit until it is decided.
                offset = self.quarter_val
                self.pending_bits += 1
            else:
                return
            # Remove the resolved prefix and double the interval.
            self.low = (self.low - offset) << 1
            self.high = ((self.high - offset) << 1) | 1

    def _write_bit(self, bit):
        """Write *bit*, then every deferred bit as its complement."""
        sink = self.bit_output
        sink.write_bit(bit)
        complement = 1 - bit
        while self.pending_bits > 0:
            sink.write_bit(complement)
            self.pending_bits -= 1

    def finish(self):
        """Write the closing bits that pin a value inside the final interval."""
        self.pending_bits += 1
        self._write_bit(0 if self.low < self.quarter_val else 1)
class ArithmeticDecoder:
    """Integer arithmetic decoder reading its code bits from a bit source.

    Keeps the same ``[low, high]`` interval arithmetic as the encoder plus a
    ``value`` register holding the next ``precision`` bits of the code
    stream. Bits requested past the end of the source are read as 0.
    """

    def __init__(self, bit_input, precision=ARITHMETIC_PRECISION):
        """Prime ``value`` with the first *precision* bits of *bit_input*."""
        self.bit_input = bit_input
        self.precision = precision
        self.max_val = (1 << precision) - 1
        self.quarter_val = 1 << (precision - 2)
        self.half_val = 1 << (precision - 1)
        self.three_quarter_val = self.quarter_val * 3
        self.low = 0
        self.high = self.max_val
        self.value = 0
        for _ in range(precision):
            self.value = (self.value << 1) | self._next_bit()

    def _next_bit(self):
        """Return the next input bit, substituting 0 for the -1 EOF marker."""
        bit = self.bit_input.read_bit()
        return 0 if bit == -1 else bit

    def decode_symbol_find_count(self, total_count):
        """Map ``value`` to a cumulative count on the *total_count* scale."""
        span = self.high - self.low + 1
        return ((self.value - self.low + 1) * total_count - 1) // span

    def update_range(self, low_count, high_count, total_count):
        """Narrow the interval to the decoded symbol's slice and rescale.

        Applies the same interval update as the encoder; each rescaling step
        shifts one fresh input bit into ``value``.
        """
        span = self.high - self.low + 1
        base = self.low
        self.high = base + (span * high_count) // total_count - 1
        self.low = base + (span * low_count) // total_count

        while True:
            if self.high < self.half_val:
                offset = 0  # lower half: nothing to subtract
            elif self.low >= self.half_val:
                offset = self.half_val  # upper half
            elif self.low >= self.quarter_val and self.high < self.three_quarter_val:
                offset = self.quarter_val  # middle straddle
            else:
                return
            self.low = (self.low - offset) << 1
            self.high = ((self.high - offset) << 1) | 1
            self.value = ((self.value - offset) << 1) | self._next_bit()
def _strip_pth(model_path):
return model_path[:-4] if model_path.endswith(".pth") else model_path
def _prepare_logits(logits):
if not isinstance(logits, torch.Tensor):
logits = torch.tensor(logits)
if logits.ndim > 1:
logits = logits[-1]
return logits.float()
def tokenize_text(tokenizer, text):
    """Encode *text* with *tokenizer* and return the ids as a list of ints.

    Accepts tokenizers whose ``encode`` returns the ids directly as well as
    ones that return an object exposing them through an ``ids`` attribute.
    """
    encoded = tokenizer.encode(text)
    ids = getattr(encoded, "ids", encoded)
    return list(map(int, ids))
def decode_tokens(tokenizer, tokens):
    """Return whatever ``tokenizer.decode`` produces for *tokens*."""
    decoded = tokenizer.decode(tokens)
    return decoded
# Serializes the RWKV environment setup, import and model construction.
_MODEL_LOCK = threading.Lock()


@lru_cache(maxsize=2)
def load_rwkv_model(model_path, tokenizer_name, strategy):
    """Load an RWKV model and its tokenizer (memoized per argument triple).

    Args:
        model_path: Path to the checkpoint; a trailing ``.pth`` is stripped
            before it is handed to ``RWKV``.
        tokenizer_name: Name or path passed to ``TRIE_TOKENIZER``.
        strategy: RWKV strategy string. If it mentions ``cuda`` but CUDA is
            unavailable, ``"cpu fp32"`` is used instead.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If *model_path* or *tokenizer_name* is empty.
    """
    if not model_path:
        raise ValueError("RWKV model path is required.")
    if not tokenizer_name:
        raise ValueError("RWKV tokenizer name or path is required.")

    if "cuda" in strategy and not torch.cuda.is_available():
        strategy = "cpu fp32"

    with _MODEL_LOCK:
        # The flags are written while holding the lock so a concurrent call
        # with a different strategy cannot overwrite RWKV_CUDA_ON between
        # this call's assignment and its import / model construction.
        # NOTE(review): presumably the rwkv package reads these flags at
        # import or construction time - confirm against the rwkv docs.
        os.environ["RWKV_JIT_ON"] = "1"
        os.environ["RWKV_V7_ON"] = "1"
        os.environ["RWKV_CUDA_ON"] = "1" if "cuda" in strategy else "0"

        from rwkv.model import RWKV
        from rwkv.rwkv_tokenizer import TRIE_TOKENIZER

        model = RWKV(model=_strip_pth(model_path), strategy=strategy)
        tokenizer = TRIE_TOKENIZER(tokenizer_name)
    return model, tokenizer
def compress_tokens(
    tokens,
    model,
    context_window=2048,
    original_bytes=None,
    progress=None,
    progress_desc="Compressing",
):
    """Arithmetic-code a token sequence using *model*'s next-token probabilities.

    The output starts with a 4-byte big-endian unsigned token count followed
    by the arithmetic-coded bit stream. For every token the model is fed the
    previous token (0 at the start of each context window), its softmax is
    quantized to integer counts with ``PROB_SCALE`` (each at least 1), and
    the token's cumulative-count slice is encoded.

    Args:
        tokens: Iterable of token ids; each is converted with ``int``.
        model: Object whose ``forward([token], state)`` returns
            ``(logits, state)``.
        context_window: Number of tokens after which the model state is
            reset. ``decompress_bytes`` performs the same reset, so it has
            to be called with the same value.
        original_bytes: Size of the uncompressed input, used only for the
            ratio statistics; ``None`` is treated as 0.
        progress: Optional callback invoked as
            ``progress((done, total), desc=..., unit="token")``. It is called
            once at the start, then roughly every 1% of the tokens and on the
            final token (the same cadence ``decompress_bytes`` uses).
        progress_desc: Description forwarded to *progress*.

    Returns:
        ``(data, stats)`` where *data* is the compressed ``bytes`` and
        *stats* is a dict of counters, ratios and timing.

    Raises:
        ValueError: If *context_window* is not positive, *tokens* is empty,
            or any token id is negative.
    """
    if context_window <= 0:
        raise ValueError("context_window must be positive.")
    token_ids = [int(token_id) for token_id in tokens]
    if not token_ids:
        raise ValueError("No tokens to compress.")
    # A negative id would be accepted by tensor indexing (wrapping around to
    # the end of the vocabulary) and encode a wrong interval, producing a
    # stream that silently fails to round-trip. Reject it before encoding.
    if min(token_ids) < 0:
        raise ValueError("Token ids must be non-negative.")
    total_tokens = len(token_ids)

    output = io.BytesIO()
    output.write(struct.pack(">I", total_tokens))
    bit_output = BitOutputStream(output)
    encoder = ArithmeticEncoder(bit_output, precision=ARITHMETIC_PRECISION)

    context_tokens = []
    state = None
    total_nll = 0.0  # sum of -ln p(token), for the theoretical size
    start_time = time.time()

    if progress is not None:
        progress((0, total_tokens), desc=progress_desc, unit="token")
    # Report about 100 times in total rather than once per token.
    progress_step = max(1, total_tokens // 100)

    with torch.inference_mode():
        for idx, token_id in enumerate(token_ids):
            # Start a fresh context once the window is full.
            if len(context_tokens) >= context_window:
                context_tokens = []
                state = None

            input_token = context_tokens[-1] if context_tokens else 0
            logits, state = model.forward([input_token], state)
            next_logits = _prepare_logits(logits)
            probs = torch.softmax(next_logits, dim=-1)

            # Quantize to integer counts; the clamp guarantees every symbol
            # keeps a non-empty slice of the cumulative distribution.
            counts = (probs * PROB_SCALE).to(torch.long)
            counts = torch.clamp(counts, min=1)
            cdf = torch.cumsum(counts, dim=-1)
            total_count = int(cdf[-1].item())

            prob_val = probs[token_id]
            total_nll += float((-torch.log(prob_val)).item())

            low_val = int(cdf[token_id - 1].item()) if token_id > 0 else 0
            high_val = int(cdf[token_id].item())
            encoder.encode_symbol(low_val, high_val, total_count)
            context_tokens.append(token_id)

            done = idx + 1
            if progress is not None and (done == total_tokens or done % progress_step == 0):
                progress((done, total_tokens), desc=progress_desc, unit="token")

    encoder.finish()
    bit_output.close()
    data = output.getvalue()
    end_time = time.time()

    original_bytes = int(original_bytes or 0)
    compressed_bytes = len(data)
    ratio = compressed_bytes / original_bytes if original_bytes > 0 else 0.0
    # Convert the accumulated natural-log loss to bits, then to bytes.
    theoretical_bits = total_nll / math.log(2)
    theoretical_bytes = theoretical_bits / 8
    theoretical_ratio = theoretical_bytes / original_bytes if original_bytes > 0 else 0.0
    duration = end_time - start_time
    speed = total_tokens / duration if duration > 0 else 0.0

    stats = {
        "tokens": total_tokens,
        "original_bytes": original_bytes,
        "compressed_bytes": compressed_bytes,
        "ratio": ratio,
        "theoretical_ratio": theoretical_ratio,
        "duration_s": duration,
        "speed_toks_per_s": speed,
    }
    return data, stats
def compress_text(text, model, tokenizer, context_window=2048):
    """Tokenize *text* and compress the resulting ids with ``compress_tokens``.

    The UTF-8 byte length of *text* is passed along as ``original_bytes`` so
    the returned statistics can report compression ratios.

    Returns:
        The ``(data, stats)`` tuple produced by ``compress_tokens``.
    """
    token_ids = tokenize_text(tokenizer, text)
    utf8_size = len(text.encode("utf-8"))
    return compress_tokens(
        token_ids,
        model,
        context_window=context_window,
        original_bytes=utf8_size,
    )
def decompress_bytes(
    data,
    model,
    tokenizer,
    context_window=2048,
    progress=None,
    progress_desc="Decompressing",
):
    """Decode a stream produced by ``compress_tokens`` back into text.

    Reads the 4-byte big-endian token count, then replays the model token by
    token: each step rebuilds the quantized cumulative distribution, looks up
    the symbol selected by the arithmetic decoder and feeds it back as the
    next model input. The model state is reset every *context_window* tokens,
    mirroring the reset in ``compress_tokens``.

    Args:
        data: Compressed bytes (header plus coded bits).
        model: Object whose ``forward([token], state)`` returns
            ``(logits, state)``.
        tokenizer: Object whose ``decode`` turns the token ids into text.
        context_window: Reset interval; has to match the value used when
            compressing.
        progress: Optional callback invoked as
            ``progress((done, total), desc=..., unit="token")`` at the start,
            about every 1% of the tokens, and on the last token.
        progress_desc: Description forwarded to *progress*.

    Returns:
        ``(text, stats)`` where *stats* holds the token count and duration.

    Raises:
        ValueError: If *context_window* is not positive or *data* is empty or
            shorter than the 4-byte header.
    """
    if context_window <= 0:
        raise ValueError("context_window must be positive.")
    if not data or len(data) < 4:
        raise ValueError("Compressed data is empty or invalid.")

    stream = io.BytesIO(data)
    (total_tokens,) = struct.unpack(">I", stream.read(4))
    decoder = ArithmeticDecoder(BitInputStream(stream), precision=ARITHMETIC_PRECISION)

    decoded_tokens = []
    window_fill = 0  # tokens decoded since the last context reset
    previous_token = 0  # most recently decoded token
    state = None
    start_time = time.time()

    if progress is not None:
        progress((0, total_tokens), desc=progress_desc, unit="token")
    progress_step = max(1, total_tokens // 100)

    with torch.inference_mode():
        for done in range(1, total_tokens + 1):
            # Start a fresh context once the window is full.
            if window_fill >= context_window:
                window_fill = 0
                state = None

            # The first step of every window is seeded with token 0.
            input_token = previous_token if window_fill else 0
            logits, state = model.forward([input_token], state)
            probs = torch.softmax(_prepare_logits(logits), dim=-1)
            counts = torch.clamp((probs * PROB_SCALE).to(torch.long), min=1)
            cdf = torch.cumsum(counts, dim=-1)
            total_count = int(cdf[-1].item())

            # First index whose cumulative count exceeds the decoded target.
            target = decoder.decode_symbol_find_count(total_count)
            target_tensor = torch.tensor(target, device=cdf.device)
            token_id = int(torch.searchsorted(cdf, target_tensor, right=True).item())

            decoded_tokens.append(token_id)
            previous_token = token_id
            window_fill += 1

            if token_id > 0:
                low_val = int(cdf[token_id - 1].item())
            else:
                low_val = 0
            high_val = int(cdf[token_id].item())
            decoder.update_range(low_val, high_val, total_count)

            if progress is not None and (done == total_tokens or done % progress_step == 0):
                progress((done, total_tokens), desc=progress_desc, unit="token")

    text = decode_tokens(tokenizer, decoded_tokens)
    elapsed = time.time() - start_time
    return text, {"tokens": total_tokens, "duration_s": elapsed}