# LLM-Compressor / app.py
# Repository page header (not code): "Jellyfish042's picture" · commit "update" · d490809
import base64
import inspect
import os
import shutil
import tempfile
import urllib.request
from pathlib import Path
import gradio as gr
import torch
from llm_compressor import compress_tokens, decompress_bytes, load_rwkv_model, tokenize_text
# Upper bound on the number of characters the compress handler accepts.
MAX_INPUT_CHARS = 16384

# Directory layout, anchored at the folder that contains this script.
SCRIPT_DIR = Path(__file__).parent.absolute()
SUPPORT_DIR = SCRIPT_DIR.joinpath("support")
MODELS_DIR = SCRIPT_DIR.joinpath("models")

# Default RWKV checkpoint: file name, expected local path and download URL.
DEFAULT_MODEL_FILENAME = "rwkv7-g1a-0.1b-20250728-ctx4096.pth"
DEFAULT_MODEL_PATH = str(MODELS_DIR.joinpath(DEFAULT_MODEL_FILENAME))
DEFAULT_MODEL_URL = "https://huggingface.co/BlinkDL/rwkv7-g1/resolve/main/rwkv7-g1a-0.1b-20250728-ctx4096.pth?download=true"

# Default tokenizer vocabulary, expected inside the support directory.
DEFAULT_TOKENIZER_PATH = str(SUPPORT_DIR.joinpath("rwkv_vocab_v20230424.txt"))
def _patch_gradio_client_schema():
    """Make gradio_client's schema helpers tolerate boolean JSON schemas.

    Wraps ``gradio_client.utils.get_type`` and
    ``gradio_client.utils._json_schema_to_python_type`` so that a schema that
    is a plain ``bool`` yields ``"Any"``; every other schema is forwarded to
    the original function unchanged.

    The patch is best-effort: it does nothing when ``gradio_client`` cannot be
    imported, it is applied at most once (tracked by the ``_rwkv_patch``
    marker on the module), and a helper that does not exist in the installed
    version is skipped instead of raising ``AttributeError``.
    """
    try:
        from gradio_client import utils as gr_client_utils
    except Exception:
        # Deliberately broad: any import-time failure just means "no patch".
        return

    if getattr(gr_client_utils, "_rwkv_patch", False):
        return

    original_get_type = getattr(gr_client_utils, "get_type", None)
    original_json_schema = getattr(gr_client_utils, "_json_schema_to_python_type", None)

    if callable(original_get_type):

        def _patched_get_type(schema):
            if isinstance(schema, bool):
                return "Any"
            return original_get_type(schema)

        gr_client_utils.get_type = _patched_get_type

    if callable(original_json_schema):

        def _patched_json_schema(schema, defs=None):
            if isinstance(schema, bool):
                return "Any"
            return original_json_schema(schema, defs)

        gr_client_utils._json_schema_to_python_type = _patched_json_schema

    gr_client_utils._rwkv_patch = True
# Apply the gradio_client schema workaround at import time, before any UI is built.
_patch_gradio_client_schema()
def _write_temp_file(data, suffix=".llmc"):
    """Write ``data`` to a new temporary file and return its path.

    The file is created with ``delete=False`` so it survives this call; the
    caller owns the returned path.

    Args:
        data: Bytes-like payload to store.
        suffix: File name suffix for the temporary file.

    Returns:
        The path of the written file as a string.

    Raises:
        Whatever the write raises (e.g. ``TypeError`` for non-bytes data). In
        that case the handle is closed and the partial file is removed, so a
        failed call leaves nothing behind.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        # The context manager flushes and closes the handle on both paths.
        with tmp:
            tmp.write(data)
    except Exception:
        try:
            os.unlink(tmp.name)
        except OSError:
            pass
        raise
    return tmp.name
def _resolve_default_model_path():
    """Pick the model checkpoint path to use when the user supplies none.

    Lookup order: the ``RWKV_MODEL_PATH`` environment variable, the default
    checkpoint on disk, a fresh download of the default checkpoint, then the
    alphabetically first ``*.pth`` file in the models directory.

    Returns:
        The chosen path as a string, or ``""`` when nothing was found.
    """
    override = os.getenv("RWKV_MODEL_PATH")
    if override:
        return override

    bundled = Path(DEFAULT_MODEL_PATH)
    if bundled.is_file():
        return str(bundled)

    fetched = _download_default_model() if DEFAULT_MODEL_URL else ""
    if fetched:
        return fetched

    checkpoints = sorted(MODELS_DIR.glob("*.pth")) if MODELS_DIR.is_dir() else []
    return str(checkpoints[0]) if checkpoints else ""
def _resolve_default_tokenizer_path():
    """Pick the tokenizer vocabulary path to use when the user supplies none.

    Lookup order: the ``RWKV_TOKENIZER`` environment variable, the default
    vocabulary file, then the alphabetically first ``rwkv_vocab_v*.txt`` in
    the support directory.

    Returns:
        The chosen path as a string. When nothing exists on disk the default
        path is still returned, so the caller can report it as missing.
    """
    override = os.getenv("RWKV_TOKENIZER")
    if override:
        return override

    fallback = Path(DEFAULT_TOKENIZER_PATH)
    if fallback.is_file():
        return str(fallback)

    vocab_files = sorted(SUPPORT_DIR.glob("rwkv_vocab_v*.txt")) if SUPPORT_DIR.is_dir() else []
    chosen = vocab_files[0] if vocab_files else fallback
    return str(chosen)
def _download_default_model(timeout=60.0):
    """Download the default RWKV checkpoint unless it is already on disk.

    The payload is streamed to a sibling ``*.tmp`` file and renamed into place
    only after the copy finished, so a partial download never appears under
    the final name.

    Args:
        timeout: Seconds passed to ``urllib.request.urlopen`` as the timeout
            for blocking network operations, so a stalled connection cannot
            hang the caller forever.

    Returns:
        The checkpoint path as a string, or ``""`` when no URL is configured
        or any step (creating the directory, downloading, renaming) failed.
    """
    if not DEFAULT_MODEL_URL:
        return ""
    dest_path = Path(DEFAULT_MODEL_PATH)
    if dest_path.is_file():
        return str(dest_path)

    # NOTE(review): the temp name is fixed, so two simultaneous downloads would
    # write to the same file — confirm callers never run this concurrently.
    tmp_path = dest_path.with_suffix(dest_path.suffix + ".tmp")
    try:
        # Inside the try so a filesystem error degrades to "" like any other
        # download failure instead of propagating to the caller.
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        print(f"Downloading RWKV model to {dest_path}...")
        with urllib.request.urlopen(DEFAULT_MODEL_URL, timeout=timeout) as response, open(tmp_path, "wb") as f:
            shutil.copyfileobj(response, f)
        tmp_path.replace(dest_path)
        return str(dest_path)
    except Exception as exc:
        # Best-effort cleanup; it must never mask the original failure.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        print(f"Failed to download RWKV model: {exc}")
        return ""
def _resolve_model_path(value):
    """Locate an existing model checkpoint for a user-supplied name or path.

    Candidates are tried in order: the path as given, the path with ``.pth``
    appended, the path with its last suffix replaced by ``.pth`` (kept for
    backward compatibility), and — for relative paths — the same three
    variants inside the models directory.

    Appending rather than only replacing the suffix matters for dotted names:
    ``Path("rwkv7-g1a-0.1b-ctx4096").with_suffix(".pth")`` yields
    ``rwkv7-g1a-0.pth``, which would never match the real file.

    Args:
        value: Model name or path; ``~`` is expanded. May be empty or None.

    Returns:
        The first candidate that is an existing file, as a ``Path``, or
        ``None`` when ``value`` is empty or nothing matches.
    """
    if not value:
        return None
    path = Path(value).expanduser()

    def _variants(base):
        yield base
        # An empty name (e.g. ".") cannot carry a suffix; with_name/with_suffix
        # would raise ValueError for it.
        if base.name and base.suffix != ".pth":
            yield base.with_name(base.name + ".pth")
            yield base.with_suffix(".pth")

    candidates = list(_variants(path))
    if not path.is_absolute():
        candidates.extend(_variants(MODELS_DIR / path))

    for candidate in candidates:
        if candidate.is_file():
            return candidate
    return None
def _resolve_tokenizer_path(value):
    """Locate an existing tokenizer vocabulary file.

    The path is tried as given (with ``~`` expanded) and, when it is relative,
    also inside the support directory.

    Returns:
        The first location that is an existing file, as a ``Path``, or
        ``None`` when ``value`` is empty or no location matches.
    """
    if not value:
        return None
    given = Path(value).expanduser()
    if given.is_absolute():
        search = (given,)
    else:
        search = (given, SUPPORT_DIR / given)
    return next((location for location in search if location.is_file()), None)
def _resolve_strategy():
    """Return the strategy to run with.

    Reads ``RWKV_STRATEGY`` (default ``"cpu fp32"``) and passes it through
    ``_normalize_strategy``.
    """
    requested = os.getenv("RWKV_STRATEGY", "cpu fp32")
    return _normalize_strategy(requested)
def _extract_file_bytes(file_data):
    """Return the raw content of an uploaded file, whatever form it arrives in.

    Accepted forms: ``None`` (returned as ``None``), bytes/bytearray, a dict
    carrying a ``"data"`` entry (that entry is returned as-is), a string
    (treated as a file path and read in binary mode), or any object with a
    ``read`` method.

    Raises:
        gr.Error: If ``file_data`` matches none of the accepted forms.
    """
    if file_data is None:
        return None

    if isinstance(file_data, dict):
        if "data" in file_data:
            return file_data["data"]
    elif isinstance(file_data, (bytes, bytearray)):
        return bytes(file_data)
    elif isinstance(file_data, str):
        with open(file_data, "rb") as handle:
            return handle.read()

    # Last resort: anything file-like.
    if hasattr(file_data, "read"):
        return file_data.read()
    raise gr.Error("Unsupported uploaded file format.")
def _get_compressed_bytes(b64_data, file_data):
    """Return the compressed payload from the upload or the base64 text box.

    A non-empty uploaded file takes precedence; otherwise ``b64_data`` is
    decoded.

    All whitespace is removed from the text before decoding. The emptiness
    check already ignores surrounding whitespace, and strict decoding
    (``validate=True``) would otherwise reject a payload merely because it was
    pasted with a trailing newline or wrapped over several lines.

    Args:
        b64_data: Base64 text, possibly None or empty.
        file_data: Uploaded file in any form ``_extract_file_bytes`` accepts.

    Returns:
        The compressed payload as bytes.

    Raises:
        gr.Error: If there is no file content and the base64 text is empty or
            not valid base64.
    """
    file_bytes = _extract_file_bytes(file_data)
    if file_bytes:
        return file_bytes

    compact = "".join(b64_data.split()) if b64_data else ""
    if not compact:
        raise gr.Error("Compressed base64 data is empty.")
    try:
        return base64.b64decode(compact.encode("ascii"), validate=True)
    except ValueError as exc:
        # binascii.Error and UnicodeEncodeError are both ValueError subclasses.
        raise gr.Error(f"Invalid base64 data: {exc}") from exc
def _load_model_and_tokenizer(model_path, tokenizer_name, strategy):
    """Resolve both files on disk and load the RWKV model and tokenizer.

    Args:
        model_path: Model name or path, resolved via ``_resolve_model_path``.
        tokenizer_name: Vocabulary name or path, resolved via
            ``_resolve_tokenizer_path``.
        strategy: Strategy string forwarded to ``load_rwkv_model``.

    Returns:
        Whatever ``load_rwkv_model`` returns (callers unpack it as
        ``model, tokenizer``).

    Raises:
        gr.Error: If either file cannot be found or loading fails.
    """
    model_file = _resolve_model_path(model_path)
    if not model_file:
        missing_model = f"RWKV model file not found: {model_path}. Put a .pth in {MODELS_DIR} or set RWKV_MODEL_PATH."
        raise gr.Error(missing_model)

    vocab_file = _resolve_tokenizer_path(tokenizer_name)
    if not vocab_file:
        missing_vocab = f"Tokenizer vocab file not found: {tokenizer_name}. Put rwkv_vocab_v20230424.txt in {SUPPORT_DIR} or set RWKV_TOKENIZER."
        raise gr.Error(missing_vocab)

    try:
        loaded = load_rwkv_model(str(model_file), str(vocab_file), strategy)
    except Exception as exc:
        raise gr.Error(f"Failed to load RWKV model: {exc}") from exc
    return loaded
def _format_compress_stats(stats, char_count=None):
    """Render compression statistics as a Markdown bullet list.

    Args:
        stats: Mapping with the keys ``tokens``, ``original_bytes``,
            ``compressed_bytes``, ``ratio``, ``theoretical_ratio``,
            ``duration_s`` and ``speed_toks_per_s``.
        char_count: Optional character count, shown as the first bullet when
            it is not None.

    Returns:
        The bullets joined by newlines.
    """
    header = [] if char_count is None else [f"- Characters: {char_count}"]
    ratio_pct = stats["ratio"] * 100
    theoretical_pct = stats["theoretical_ratio"] * 100
    details = [
        f"- Tokens: {stats['tokens']}",
        f"- Original bytes: {stats['original_bytes']}",
        f"- Compressed bytes: {stats['compressed_bytes']}",
        f"- Compression ratio: {ratio_pct:.2f}%",
        f"- Theoretical ratio: {theoretical_pct:.2f}%",
        f"- Time: {stats['duration_s']:.2f}s",
        f"- Speed: {stats['speed_toks_per_s']:.2f} tokens/s",
    ]
    return "\n".join(header + details)
def _format_decompress_stats(stats, char_count=None):
    """Render decompression statistics as a Markdown bullet list.

    Args:
        stats: Mapping with the keys ``tokens`` and ``duration_s``.
        char_count: Optional character count, shown as the first bullet when
            it is not None.

    Returns:
        The bullets joined by newlines.
    """
    bullets = [
        f"- Tokens: {stats['tokens']}",
        f"- Time: {stats['duration_s']:.2f}s",
    ]
    if char_count is not None:
        bullets.insert(0, f"- Characters: {char_count}")
    return "\n".join(bullets)
def _normalize_strategy(strategy):
    """Fall back to ``"cpu fp32"`` when a CUDA strategy is requested without CUDA.

    Any strategy that does not mention ``cuda``, or that does while
    ``torch.cuda.is_available()`` is true, is returned unchanged.
    """
    if "cuda" not in strategy or torch.cuda.is_available():
        return strategy
    return "cpu fp32"
def _get_model_display_name():
    """Return the model name shown in the UI header.

    This is the file stem of ``RWKV_MODEL_PATH`` when that variable is set to
    a non-empty value, otherwise the stem of the default model file name.
    """
    source = os.getenv("RWKV_MODEL_PATH") or DEFAULT_MODEL_FILENAME
    return Path(source).stem
def compress_ui(text, context_window, progress=gr.Progress()):
    """Gradio handler for the "Compress" button.

    Args:
        text: Input text from the UI.
        context_window: Context window value forwarded to ``compress_tokens``.
        progress: Gradio progress tracker forwarded to ``compress_tokens``.

    Returns:
        A ``(base64_text, stats_markdown, file_path)`` tuple. When the input
        exceeds ``MAX_INPUT_CHARS`` the tuple is ``("", "- <message>", None)``
        and the message is also shown through ``gr.Info``.

    Raises:
        gr.Error: If the text is blank or tokenizes to nothing (model loading
            errors surface the same way via ``_load_model_and_tokenizer``).
    """
    if not text or not text.strip():
        raise gr.Error("Input text is empty.")

    char_count = len(text)
    if char_count > MAX_INPUT_CHARS:
        message = f"Input is too long ({char_count} chars). Max is {MAX_INPUT_CHARS}."
        gr.Info(message)
        return "", f"- {message}", None

    model_path = _resolve_default_model_path()
    tokenizer_path = _resolve_default_tokenizer_path()
    requested_strategy = os.getenv("RWKV_STRATEGY", "cpu fp32")
    effective_strategy = _resolve_strategy()
    model, tokenizer = _load_model_and_tokenizer(model_path, tokenizer_path, effective_strategy)

    tokens = tokenize_text(tokenizer, text)
    if not tokens:
        raise gr.Error("Tokenized input is empty.")

    payload, stats = compress_tokens(
        tokens,
        model,
        context_window=context_window,
        original_bytes=len(text.encode("utf-8")),
        progress=progress,
        progress_desc="Compressing",
    )
    encoded = base64.b64encode(payload).decode("ascii")
    download_path = _write_temp_file(payload)

    # Tell the user when the requested strategy was replaced by the fallback.
    if effective_strategy == requested_strategy:
        strategy_note = f"\n- Strategy: {effective_strategy}"
    else:
        strategy_note = "\n- Strategy: cpu fp32 (forced, CUDA unavailable)"
    report = _format_compress_stats(stats, char_count=char_count) + strategy_note
    return encoded, report, download_path
def decompress_ui(b64_data, file_data, context_window, progress=gr.Progress()):
    """Gradio handler for the "Decompress" button.

    Args:
        b64_data: Base64 text from the UI.
        file_data: Uploaded compressed file, if any.
        context_window: Context window value forwarded to ``decompress_bytes``.
        progress: Gradio progress tracker forwarded to ``decompress_bytes``.

    Returns:
        A ``(text, stats_markdown)`` tuple.

    Raises:
        gr.Error: Raised by the helpers when the input is missing or invalid,
            or when the model cannot be loaded.
    """
    payload = _get_compressed_bytes(b64_data, file_data)

    model_path = _resolve_default_model_path()
    tokenizer_path = _resolve_default_tokenizer_path()
    requested_strategy = os.getenv("RWKV_STRATEGY", "cpu fp32")
    effective_strategy = _resolve_strategy()
    model, tokenizer = _load_model_and_tokenizer(model_path, tokenizer_path, effective_strategy)

    restored, stats = decompress_bytes(
        payload,
        model,
        tokenizer,
        context_window=context_window,
        progress=progress,
        progress_desc="Decompressing",
    )

    # Tell the user when the requested strategy was replaced by the fallback.
    if effective_strategy == requested_strategy:
        strategy_note = f"\n- Strategy: {effective_strategy}"
    else:
        strategy_note = "\n- Strategy: cpu fp32 (forced, CUDA unavailable)"
    report = _format_decompress_stats(stats, char_count=len(restored)) + strategy_note
    return restored, report
def build_ui():
    """Assemble the Gradio Blocks interface for the compressor demo.

    The page consists of an HTML header (title, disclaimer, project links and
    the model name), a context-window slider used by both handlers, and two
    tabs: "Compress" wired to ``compress_ui`` and "Decompress" wired to
    ``decompress_ui``.

    Returns:
        The constructed Blocks object; launching it is left to the caller.
    """
    model_display = _get_model_display_name()
    with gr.Blocks() as demo:
        # Static page header. NOTE(review): model_display is interpolated into
        # the HTML without escaping; it comes from RWKV_MODEL_PATH or the
        # default model file name.
        gr.HTML(
            f"""
<div style="text-align: center; margin-bottom: 16px;">
<h1 style="margin-bottom: 8px;">LLM Text Compressor</h1>
<p style="margin-bottom: 12px; color: #666;">
This is a proof-of-concept demo. Compression and decompression are slow,
and the output is not portable across different environments.
</p>
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; flex-wrap: wrap;">
<a href="https://github.com/Jellyfish042/uncheatable_eval" target="_blank" style="text-decoration: none;">
<img src="https://img.shields.io/badge/GitHub-Project-181717?logo=github" alt="GitHub Project">
</a>
<a href="https://huggingface.co/spaces/Jellyfish042/UncheatableEval" target="_blank" style="text-decoration: none;">
<img src="https://img.shields.io/badge/%F0%9F%8F%86%20Leaderboard-Gradio-ff7c00" alt="Leaderboard">
</a>
<a href="https://huggingface.co/spaces/Jellyfish042/Compression-Lens" target="_blank" style="text-decoration: none;">
<img src="https://img.shields.io/badge/%F0%9F%94%AC%20Compression--Lens-Visualization-blue" alt="Compression Lens">
</a>
</div>
<div style="margin-top: 10px; font-size: 0.95em; color: #444;">
Model: <code>{model_display}</code>
</div>
</div>
"""
        )
        # gr.Markdown("If CUDA is unavailable, the app forces the strategy to cpu fp32.")
        # One slider shared by both tabs: it is an input of both click handlers.
        context_window = gr.Slider(
            label="Context window",
            minimum=128,
            maximum=4096,
            step=128,
            value=4096,
        )
        gr.Markdown(f"Max input size: {MAX_INPUT_CHARS} characters.")
        with gr.Tabs():
            with gr.Tab("Compress"):
                input_text = gr.Textbox(label="Input text", lines=10)
                compress_button = gr.Button("Compress")
                output_b64 = gr.Textbox(label="Compressed data (base64)", lines=6)
                compress_stats = gr.Markdown()
                output_file = gr.File(label="Download compressed file")
                # Outputs match compress_ui's (base64, stats, file path) tuple.
                compress_button.click(
                    compress_ui,
                    inputs=[input_text, context_window],
                    outputs=[output_b64, compress_stats, output_file],
                )
            with gr.Tab("Decompress"):
                input_b64 = gr.Textbox(label="Compressed data (base64)", lines=6)
                # presumably type="binary" hands the upload to the handler as
                # raw bytes — confirm against the installed gradio version.
                input_file = gr.File(label="Or upload compressed file", type="binary")
                decompress_button = gr.Button("Decompress")
                output_text = gr.Textbox(label="Decompressed text", lines=10)
                decompress_stats = gr.Markdown()
                # Outputs match decompress_ui's (text, stats) tuple.
                decompress_button.click(
                    decompress_ui,
                    inputs=[input_b64, input_file, context_window],
                    outputs=[output_text, decompress_stats],
                )
    return demo
if __name__ == "__main__":
    # Listen on all interfaces, port 7860, without a public share link.
    launch_kwargs = dict(server_name="0.0.0.0", server_port=7860, share=False)

    # Only pass show_api when this gradio version's launch() accepts it.
    try:
        supports_show_api = "show_api" in inspect.signature(gr.Blocks.launch).parameters
    except (TypeError, ValueError):
        supports_show_api = False
    if supports_show_api:
        launch_kwargs["show_api"] = False

    app = build_ui()
    app.queue(max_size=16).launch(**launch_kwargs)