Spaces:
Sleeping
Sleeping
| import base64 | |
| import inspect | |
| import os | |
| import shutil | |
| import tempfile | |
| import urllib.request | |
| from pathlib import Path | |
| import gradio as gr | |
| import torch | |
| from llm_compressor import compress_tokens, decompress_bytes, load_rwkv_model, tokenize_text | |
# Maximum number of characters accepted from the UI textbox (CPU compression
# of longer inputs would be prohibitively slow).
MAX_INPUT_CHARS = 16384
# Directory layout is resolved relative to this script's own location.
SCRIPT_DIR = Path(__file__).parent.absolute()
SUPPORT_DIR = SCRIPT_DIR / "support"  # tokenizer vocab files live here
MODELS_DIR = SCRIPT_DIR / "models"  # .pth model checkpoints live here
# Default RWKV checkpoint; fetched on demand from Hugging Face when absent.
DEFAULT_MODEL_FILENAME = "rwkv7-g1a-0.1b-20250728-ctx4096.pth"
DEFAULT_MODEL_PATH = str(MODELS_DIR / DEFAULT_MODEL_FILENAME)
DEFAULT_MODEL_URL = "https://huggingface.co/BlinkDL/rwkv7-g1/resolve/main/" "rwkv7-g1a-0.1b-20250728-ctx4096.pth?download=true"
DEFAULT_TOKENIZER_PATH = str(SUPPORT_DIR / "rwkv_vocab_v20230424.txt")
| def _patch_gradio_client_schema(): | |
| try: | |
| from gradio_client import utils as gr_client_utils | |
| except Exception: | |
| return | |
| if getattr(gr_client_utils, "_rwkv_patch", False): | |
| return | |
| original_get_type = gr_client_utils.get_type | |
| original_json_schema = gr_client_utils._json_schema_to_python_type | |
| def _patched_get_type(schema): | |
| if isinstance(schema, bool): | |
| return "Any" | |
| return original_get_type(schema) | |
| gr_client_utils.get_type = _patched_get_type | |
| gr_client_utils._json_schema_to_python_type = lambda schema, defs=None: "Any" if isinstance(schema, bool) else original_json_schema(schema, defs) | |
| gr_client_utils._rwkv_patch = True | |
| _patch_gradio_client_schema() | |
| def _write_temp_file(data, suffix=".llmc"): | |
| tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix) | |
| tmp.write(data) | |
| tmp.flush() | |
| tmp.close() | |
| return tmp.name | |
def _resolve_default_model_path():
    """Pick the model checkpoint to use, in priority order.

    1. ``RWKV_MODEL_PATH`` environment variable, if set.
    2. The bundled default path, if that file already exists.
    3. A fresh download of the default checkpoint, if a URL is configured.
    4. Any ``*.pth`` already present in MODELS_DIR (first alphabetically).

    Returns "" when nothing is found.
    """
    override = os.getenv("RWKV_MODEL_PATH")
    if override:
        return override
    bundled = Path(DEFAULT_MODEL_PATH)
    if bundled.is_file():
        return str(bundled)
    if DEFAULT_MODEL_URL:
        fetched = _download_default_model()
        if fetched:
            return fetched
    if MODELS_DIR.is_dir():
        found = sorted(MODELS_DIR.glob("*.pth"))
        if found:
            return str(found[0])
    return ""
def _resolve_default_tokenizer_path():
    """Pick the tokenizer vocab file to use.

    Priority: ``RWKV_TOKENIZER`` env var, then the bundled default file,
    then any ``rwkv_vocab_v*.txt`` found in SUPPORT_DIR.  Falls back to the
    (possibly missing) default path so downstream errors name a concrete
    file.
    """
    override = os.getenv("RWKV_TOKENIZER")
    if override:
        return override
    bundled = Path(DEFAULT_TOKENIZER_PATH)
    if bundled.is_file():
        return str(bundled)
    if SUPPORT_DIR.is_dir():
        found = sorted(SUPPORT_DIR.glob("rwkv_vocab_v*.txt"))
        if found:
            return str(found[0])
    return str(bundled)
def _download_default_model():
    """Download the default RWKV checkpoint into MODELS_DIR.

    Writes to a ``.tmp`` sibling first and renames on success (atomic on
    POSIX), so a partial download never masquerades as a valid checkpoint.

    Returns:
        The checkpoint path as a string, or "" when no URL is configured or
        the download failed.
    """
    if not DEFAULT_MODEL_URL:
        return ""
    dest_path = Path(DEFAULT_MODEL_PATH)
    if dest_path.is_file():
        return str(dest_path)  # already present — nothing to do
    dest_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = dest_path.with_suffix(dest_path.suffix + ".tmp")
    try:
        print(f"Downloading RWKV model to {dest_path}...")
        # Fix: the original urlopen had no timeout, so a hung connection
        # could block the app forever.  The timeout bounds connect/read
        # stalls (not the total transfer time of the multi-hundred-MB file).
        with urllib.request.urlopen(DEFAULT_MODEL_URL, timeout=60) as response, open(tmp_path, "wb") as f:
            shutil.copyfileobj(response, f)
        tmp_path.replace(dest_path)
        return str(dest_path)
    except Exception as exc:
        # Best-effort cleanup of the partial file; missing_ok avoids the
        # exists()/unlink() race of the original.
        tmp_path.unlink(missing_ok=True)
        print(f"Failed to download RWKV model: {exc}")
        return ""
| def _resolve_model_path(value): | |
| if not value: | |
| return None | |
| path = Path(value).expanduser() | |
| candidates = [path] | |
| if path.suffix != ".pth": | |
| candidates.append(path.with_suffix(".pth")) | |
| if not path.is_absolute(): | |
| candidates.append(MODELS_DIR / path) | |
| if path.suffix != ".pth": | |
| candidates.append((MODELS_DIR / path).with_suffix(".pth")) | |
| for candidate in candidates: | |
| if candidate.is_file(): | |
| return candidate | |
| return None | |
| def _resolve_tokenizer_path(value): | |
| if not value: | |
| return None | |
| path = Path(value).expanduser() | |
| candidates = [path] | |
| if not path.is_absolute(): | |
| candidates.append(SUPPORT_DIR / path) | |
| for candidate in candidates: | |
| if candidate.is_file(): | |
| return candidate | |
| return None | |
def _resolve_strategy():
    """Return the effective RWKV strategy.

    Reads the ``RWKV_STRATEGY`` env var (default ``"cpu fp32"``) and
    downgrades CUDA strategies to CPU when no GPU is available.
    """
    requested = os.getenv("RWKV_STRATEGY", "cpu fp32")
    return _normalize_strategy(requested)
| def _extract_file_bytes(file_data): | |
| if file_data is None: | |
| return None | |
| if isinstance(file_data, (bytes, bytearray)): | |
| return bytes(file_data) | |
| if isinstance(file_data, dict) and "data" in file_data: | |
| return file_data["data"] | |
| if isinstance(file_data, str): | |
| with open(file_data, "rb") as f: | |
| return f.read() | |
| if hasattr(file_data, "read"): | |
| return file_data.read() | |
| raise gr.Error("Unsupported uploaded file format.") | |
def _get_compressed_bytes(b64_data, file_data):
    """Return the compressed payload from either the upload or the textbox.

    A non-empty uploaded file wins; otherwise the base64 textbox content is
    strictly decoded.

    Raises:
        gr.Error: when both sources are empty or the base64 is malformed.
    """
    uploaded = _extract_file_bytes(file_data)
    if uploaded:
        return uploaded
    if not b64_data or not b64_data.strip():
        raise gr.Error("Compressed base64 data is empty.")
    try:
        return base64.b64decode(b64_data.encode("ascii"), validate=True)
    except Exception as exc:
        raise gr.Error(f"Invalid base64 data: {exc}") from exc
def _load_model_and_tokenizer(model_path, tokenizer_name, strategy):
    """Resolve both file paths and load the RWKV model + tokenizer pair.

    Raises:
        gr.Error: when either file cannot be found, or the model fails to
            load (the underlying exception is chained for debugging).
    """
    model_file = _resolve_model_path(model_path)
    if not model_file:
        raise gr.Error(f"RWKV model file not found: {model_path}. Put a .pth in {MODELS_DIR} or set RWKV_MODEL_PATH.")
    vocab_file = _resolve_tokenizer_path(tokenizer_name)
    if not vocab_file:
        raise gr.Error(f"Tokenizer vocab file not found: {tokenizer_name}. Put rwkv_vocab_v20230424.txt in {SUPPORT_DIR} " "or set RWKV_TOKENIZER.")
    try:
        return load_rwkv_model(str(model_file), str(vocab_file), strategy)
    except Exception as exc:
        raise gr.Error(f"Failed to load RWKV model: {exc}") from exc
| def _format_compress_stats(stats, char_count=None): | |
| lines = [] | |
| if char_count is not None: | |
| lines.append(f"- Characters: {char_count}") | |
| lines.extend( | |
| [ | |
| f"- Tokens: {stats['tokens']}", | |
| f"- Original bytes: {stats['original_bytes']}", | |
| f"- Compressed bytes: {stats['compressed_bytes']}", | |
| f"- Compression ratio: {stats['ratio'] * 100:.2f}%", | |
| f"- Theoretical ratio: {stats['theoretical_ratio'] * 100:.2f}%", | |
| f"- Time: {stats['duration_s']:.2f}s", | |
| f"- Speed: {stats['speed_toks_per_s']:.2f} tokens/s", | |
| ] | |
| ) | |
| return "\n".join(lines) | |
| def _format_decompress_stats(stats, char_count=None): | |
| lines = [] | |
| if char_count is not None: | |
| lines.append(f"- Characters: {char_count}") | |
| lines.extend( | |
| [ | |
| f"- Tokens: {stats['tokens']}", | |
| f"- Time: {stats['duration_s']:.2f}s", | |
| ] | |
| ) | |
| return "\n".join(lines) | |
| def _normalize_strategy(strategy): | |
| if "cuda" in strategy and not torch.cuda.is_available(): | |
| return "cpu fp32" | |
| return strategy | |
| def _get_model_display_name(): | |
| env_model = os.getenv("RWKV_MODEL_PATH") | |
| if env_model: | |
| return Path(env_model).stem | |
| return Path(DEFAULT_MODEL_FILENAME).stem | |
def compress_ui(text, context_window, progress=gr.Progress()):
    """Gradio handler for the Compress tab.

    Returns a (base64 string, stats markdown, temp-file path) triple; on
    over-long input it shows an info toast and returns empty outputs
    instead of raising.
    """
    if not text or not text.strip():
        raise gr.Error("Input text is empty.")
    if len(text) > MAX_INPUT_CHARS:
        message = f"Input is too long ({len(text)} chars). Max is {MAX_INPUT_CHARS}."
        gr.Info(message)
        return "", f"- {message}", None
    requested_strategy = os.getenv("RWKV_STRATEGY", "cpu fp32")
    effective_strategy = _resolve_strategy()
    model, tokenizer = _load_model_and_tokenizer(
        _resolve_default_model_path(),
        _resolve_default_tokenizer_path(),
        effective_strategy,
    )
    tokens = tokenize_text(tokenizer, text)
    if not tokens:
        raise gr.Error("Tokenized input is empty.")
    data, stats = compress_tokens(
        tokens,
        model,
        context_window=context_window,
        original_bytes=len(text.encode("utf-8")),
        progress=progress,
        progress_desc="Compressing",
    )
    stats_text = _format_compress_stats(stats, char_count=len(text))
    # Surface when the strategy was silently downgraded from CUDA to CPU.
    if effective_strategy == requested_strategy:
        stats_text += f"\n- Strategy: {effective_strategy}"
    else:
        stats_text += "\n- Strategy: cpu fp32 (forced, CUDA unavailable)"
    return base64.b64encode(data).decode("ascii"), stats_text, _write_temp_file(data)
def decompress_ui(b64_data, file_data, context_window, progress=gr.Progress()):
    """Gradio handler for the Decompress tab.

    Accepts the payload either as base64 text or as an uploaded file and
    returns a (decoded text, stats markdown) pair.
    """
    payload = _get_compressed_bytes(b64_data, file_data)
    requested_strategy = os.getenv("RWKV_STRATEGY", "cpu fp32")
    effective_strategy = _resolve_strategy()
    model, tokenizer = _load_model_and_tokenizer(
        _resolve_default_model_path(),
        _resolve_default_tokenizer_path(),
        effective_strategy,
    )
    text, stats = decompress_bytes(
        payload,
        model,
        tokenizer,
        context_window=context_window,
        progress=progress,
        progress_desc="Decompressing",
    )
    stats_text = _format_decompress_stats(stats, char_count=len(text))
    # Surface when the strategy was silently downgraded from CUDA to CPU.
    if effective_strategy == requested_strategy:
        stats_text += f"\n- Strategy: {effective_strategy}"
    else:
        stats_text += "\n- Strategy: cpu fp32 (forced, CUDA unavailable)"
    return text, stats_text
def build_ui():
    """Construct the Gradio Blocks app with Compress / Decompress tabs.

    Returns:
        gr.Blocks: the assembled (not yet launched) demo.
    """
    model_display = _get_model_display_name()
    with gr.Blocks() as demo:
        # Header: title, proof-of-concept disclaimer, project badges, and
        # the name of the model actually in use.
        gr.HTML(
            f"""
            <div style="text-align: center; margin-bottom: 16px;">
                <h1 style="margin-bottom: 8px;">LLM Text Compressor</h1>
                <p style="margin-bottom: 12px; color: #666;">
                    This is a proof-of-concept demo. Compression and decompression are slow,
                    and the output is not portable across different environments.
                </p>
                <div style="display: flex; justify-content: center; align-items: center; gap: 10px; flex-wrap: wrap;">
                    <a href="https://github.com/Jellyfish042/uncheatable_eval" target="_blank" style="text-decoration: none;">
                        <img src="https://img.shields.io/badge/GitHub-Project-181717?logo=github" alt="GitHub Project">
                    </a>
                    <a href="https://huggingface.co/spaces/Jellyfish042/UncheatableEval" target="_blank" style="text-decoration: none;">
                        <img src="https://img.shields.io/badge/%F0%9F%8F%86%20Leaderboard-Gradio-ff7c00" alt="Leaderboard">
                    </a>
                    <a href="https://huggingface.co/spaces/Jellyfish042/Compression-Lens" target="_blank" style="text-decoration: none;">
                        <img src="https://img.shields.io/badge/%F0%9F%94%AC%20Compression--Lens-Visualization-blue" alt="Compression Lens">
                    </a>
                </div>
                <div style="margin-top: 10px; font-size: 0.95em; color: #444;">
                    Model: <code>{model_display}</code>
                </div>
            </div>
            """
        )
        # gr.Markdown("If CUDA is unavailable, the app forces the strategy to cpu fp32.")
        # Shared control: both tabs read this slider as the model context window.
        context_window = gr.Slider(
            label="Context window",
            minimum=128,
            maximum=4096,
            step=128,
            value=4096,
        )
        gr.Markdown(f"Max input size: {MAX_INPUT_CHARS} characters.")
        with gr.Tabs():
            with gr.Tab("Compress"):
                input_text = gr.Textbox(label="Input text", lines=10)
                compress_button = gr.Button("Compress")
                output_b64 = gr.Textbox(label="Compressed data (base64)", lines=6)
                compress_stats = gr.Markdown()
                output_file = gr.File(label="Download compressed file")
                compress_button.click(
                    compress_ui,
                    inputs=[input_text, context_window],
                    outputs=[output_b64, compress_stats, output_file],
                )
            with gr.Tab("Decompress"):
                input_b64 = gr.Textbox(label="Compressed data (base64)", lines=6)
                # type="binary" hands the handler raw bytes rather than a path.
                input_file = gr.File(label="Or upload compressed file", type="binary")
                decompress_button = gr.Button("Decompress")
                output_text = gr.Textbox(label="Decompressed text", lines=10)
                decompress_stats = gr.Markdown()
                decompress_button.click(
                    decompress_ui,
                    inputs=[input_b64, input_file, context_window],
                    outputs=[output_text, decompress_stats],
                )
    return demo
| if __name__ == "__main__": | |
| launch_kwargs = { | |
| "server_name": "0.0.0.0", | |
| "server_port": 7860, | |
| "share": False, | |
| } | |
| try: | |
| launch_params = inspect.signature(gr.Blocks.launch).parameters | |
| if "show_api" in launch_params: | |
| launch_kwargs["show_api"] = False | |
| except (TypeError, ValueError): | |
| pass | |
| build_ui().queue(max_size=16).launch(**launch_kwargs) | |