Spaces:
Running
Running
| """ | |
| UncheatableEval Visualization - RWKV Model A vs Model B | |
| Compare byte-level prediction performance between two selectable RWKV models. | |
| Required candidate sizes are 0.1B / 0.4B / 1.5B. | |
| Models are loaded from local project directory first, and auto-downloaded when missing. | |
| """ | |
| import gc | |
| import os | |
| from pathlib import Path | |
| import re | |
| import unicodedata | |
| import gradio as gr | |
| import torch | |
# Device selection: GPU autodetect is intentionally disabled — the app is
# pinned to CPU (the original autodetect line is kept below for reference).
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = "cpu"
IS_CPU = DEVICE == "cpu"

# Model configuration: checkpoints are fetched from this Hugging Face repo.
HF_REPO_ID = "BlinkDL/rwkv7-g1"
REQUIRED_MODEL_SIZES = ["0.1b", "0.4b", "1.5b"]  # TEMP: 2.9b disabled due to OOM
# Pinned checkpoint filename per size; tried first, with a glob/regex
# fallback when the pinned file is missing locally and remotely.
PREFERRED_MODEL_FILENAMES = {
    "0.1b": "rwkv7-g1d-0.1b-20260129-ctx8192.pth",
    "0.4b": "rwkv7-g1d-0.4b-20260210-ctx8192.pth",
    "1.5b": "rwkv7-g1d-1.5b-20260212-ctx8192.pth",
    # "2.9b": "rwkv7-g1d-2.9b-20260131-ctx8192.pth",  # TEMP: disabled due to OOM
}
# Default dropdown selections, expressed as size keys.
DEFAULT_MODEL_A_SIZE = "1.5b"
DEFAULT_MODEL_B_SIZE = "0.4b"

# Get the directory where this script is located
SCRIPT_DIR = Path(__file__).parent.absolute()
MODELS_DIR = SCRIPT_DIR / "models"    # downloaded/cached checkpoints
SUPPORT_DIR = SCRIPT_DIR / "support"  # tokenizer vocab file lives here

# Text length limits (characters, measured after NFC normalization/strip)
MAX_TEXT_LENGTH = 16384
MIN_TEXT_LENGTH = 1

# Global model cache — populated by initialize_models() at startup.
_rwkv_tokenizer = None
_model_registry = {}  # label -> {filename, path, display_name, size_b, model, size_key}
_default_model_a_label = None
_default_model_b_label = None
_stats_manager = None

# Precomputed example cache — filled by load_precomputed_example().
_precomputed_html = None
_precomputed_text = None
PRECOMPUTED_DIR = SCRIPT_DIR / "precomputed"
| def _parse_size_b(filename: str): | |
| match = re.search(r"-(\d+(?:\.\d+)?)b-", filename.lower()) | |
| if not match: | |
| return None | |
| try: | |
| return float(match.group(1)) | |
| except ValueError: | |
| return None | |
| def _display_name_from_filename(filename: str) -> str: | |
| size_b = _parse_size_b(filename) | |
| size_text = f"{size_b:.1f}B" if size_b is not None else "Unknown" | |
| family = "RWKV7" | |
| if "g1d" in filename.lower(): | |
| family = "RWKV7-G1D" | |
| elif "g1c" in filename.lower(): | |
| family = "RWKV7-G1C" | |
| return f"{family}-{size_text}" | |
| def _size_to_pattern(size_key: str) -> str: | |
| return f"rwkv7-g1d-{size_key}-*.pth" | |
| def _extract_date_token(filename: str): | |
| m = re.search(r"-(\d{8})-", filename) | |
| return m.group(1) if m else "00000000" | |
| def _pick_best_filename(filenames): | |
| if not filenames: | |
| return None | |
| return sorted(filenames, key=lambda x: (_extract_date_token(x), x))[-1] | |
def _find_local_filename_for_size(size_key: str):
    """Return the best locally cached checkpoint name for this size, or None."""
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    local_names = [candidate.name for candidate in MODELS_DIR.glob(_size_to_pattern(size_key))]
    return _pick_best_filename(local_names)
def _list_repo_files():
    """List all files in the HF checkpoint repo (network call; import kept local)."""
    from huggingface_hub import HfApi
    return HfApi().list_repo_files(repo_id=HF_REPO_ID, repo_type="model")
def _find_remote_filename_for_size(size_key: str, repo_files):
    """Pick the newest remote .pth file matching this size key, or None."""
    rx = re.compile(rf"^rwkv7-g1d-{re.escape(size_key)}-.*\.pth$", re.IGNORECASE)
    candidates = [name for name in repo_files if rx.match(name)]
    return _pick_best_filename(candidates)
def _ensure_model_file(size_key: str, repo_files_cache=None) -> str:
    """Ensure one specific size model exists in local models directory.

    Resolution order: pinned preferred filename -> any local glob match ->
    download the newest matching file from the HF repo.

    Returns local absolute path.
    """
    MODELS_DIR.mkdir(parents=True, exist_ok=True)

    pinned = PREFERRED_MODEL_FILENAMES.get(size_key)
    if pinned and (MODELS_DIR / pinned).exists():
        return str(MODELS_DIR / pinned)

    cached = _find_local_filename_for_size(size_key)
    if cached:
        return str(MODELS_DIR / cached)

    if repo_files_cache is None:
        repo_files_cache = _list_repo_files()

    # Prefer the pinned filename when the repo still hosts it; otherwise
    # fall back to the newest remote file for this size.
    target = pinned if pinned in repo_files_cache else None
    if target is None:
        target = _find_remote_filename_for_size(size_key, repo_files_cache)
    if not target:
        raise RuntimeError(
            f"Could not find remote RWKV file for size {size_key} in repo {HF_REPO_ID}."
        )

    from huggingface_hub import hf_hub_download
    print(f"Downloading missing model {target} from {HF_REPO_ID} ...")
    # NOTE(review): local_dir_use_symlinks is deprecated (and ignored) in
    # recent huggingface_hub releases — confirm against the pinned version.
    downloaded = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=target,
        local_dir=str(MODELS_DIR),
        local_dir_use_symlinks=False,
    )
    return str(Path(downloaded).resolve())
def _build_candidate_specs():
    """Build required model specs from fixed size list, auto-downloading missing files.

    Returns a list of dicts with keys: label, filename, path, display_name,
    size_b, size_key. The label embeds the actual filename so each dropdown
    entry is unique and matches the pre-load fallback label format.
    """
    repo_files_cache = None
    specs = []
    for size_key in REQUIRED_MODEL_SIZES:
        try:
            model_path = _ensure_model_file(size_key, repo_files_cache=repo_files_cache)
        except Exception:
            # The first attempt may fail before the repo listing was fetched;
            # fetch it once and retry. If we already had the listing, the
            # retry would fail identically, so re-raise instead.
            if repo_files_cache is None:
                repo_files_cache = _list_repo_files()
                model_path = _ensure_model_file(size_key, repo_files_cache=repo_files_cache)
            else:
                raise
        p = Path(model_path)
        filename = p.name
        size_b = _parse_size_b(filename)
        display_name = _display_name_from_filename(filename)
        # Bug fix: the label previously contained the literal placeholder
        # "((unknown))"; embed the real filename, consistent with the
        # fallback labels built in get_model_dropdown_choices().
        label = f"{display_name} ({filename})"
        specs.append(
            {
                "label": label,
                "filename": filename,
                "path": str(p),
                "display_name": display_name,
                "size_b": size_b,
                "size_key": size_key,
            }
        )
    return specs
def _pick_default_pair(specs):
    """Choose (model_a, model_b) specs, preferring the configured default sizes."""
    if not specs:
        return None, None

    by_size = {spec["size_key"]: spec for spec in specs}
    chosen_a = by_size.get(DEFAULT_MODEL_A_SIZE)
    chosen_b = by_size.get(DEFAULT_MODEL_B_SIZE)

    if chosen_a is None:
        # Fall back to the smallest spec; unknown sizes sort after known ones.
        def smallness(spec):
            size = spec["size_b"]
            return (size is None, size if size is not None else 1e9, spec["filename"])

        chosen_a = min(specs, key=smallness)

    if chosen_b is None or chosen_b["filename"] == chosen_a["filename"]:
        # Pick any spec distinct from A; reuse A only if nothing else exists.
        alternatives = [spec for spec in specs if spec["filename"] != chosen_a["filename"]]
        chosen_b = alternatives[0] if alternatives else chosen_a

    return chosen_a, chosen_b
def _load_rwkv_model(model_path: str):
    """Load a RWKV7 model from local path."""
    # The rwkv package reads these env vars when constructing the model.
    os.environ["RWKV_JIT_ON"] = "1"
    os.environ["RWKV_V7_ON"] = "1"
    os.environ["RWKV_CUDA_ON"] = "0" if IS_CPU else "1"

    from rwkv.model import RWKV

    strategy = "cpu fp32" if IS_CPU else "cuda fp16"
    # The RWKV loader expects the checkpoint path without its .pth suffix.
    stem = model_path[:-4] if model_path.endswith(".pth") else model_path
    return RWKV(model=stem, strategy=strategy)
def _load_rwkv_tokenizer():
    """Build the RWKV trie tokenizer from the bundled world-vocab file."""
    from rwkv.rwkv_tokenizer import TRIE_TOKENIZER
    return TRIE_TOKENIZER(str(SUPPORT_DIR / "rwkv_vocab_v20230424.txt"))
def validate_input(text: str) -> tuple[bool, str]:
    """Validate input text.

    Returns (True, normalized_text) on success, otherwise
    (False, error_message). Normalization is NFC plus strip.
    """
    if not text or not text.strip():
        return False, "Please enter some text to analyze."
    normalized = unicodedata.normalize("NFC", text).strip()
    length = len(normalized)
    if length < MIN_TEXT_LENGTH:
        return False, f"Text is too short. Minimum {MIN_TEXT_LENGTH} characters required."
    if length > MAX_TEXT_LENGTH:
        return False, f"Text is too long. Maximum {MAX_TEXT_LENGTH} characters allowed. Current: {length}"
    return True, normalized
def load_precomputed_example():
    """Load precomputed example visualization.

    Fills the module-level _precomputed_html/_precomputed_text caches.
    Returns True when both precomputed files exist, False otherwise.
    """
    global _precomputed_html, _precomputed_text
    html_path = PRECOMPUTED_DIR / "example_visualization.html"
    metadata_path = PRECOMPUTED_DIR / "example_metadata.json"
    if not (html_path.exists() and metadata_path.exists()):
        print("No precomputed example found. Run precompute_example.py first.")
        return False
    import json
    _precomputed_html = html_path.read_text(encoding="utf-8")
    metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
    _precomputed_text = metadata.get("example_text", "")
    print(f"Loaded precomputed example ({len(_precomputed_text)} chars)")
    return True
def initialize_models():
    """Initialize and cache all required RWKV models at startup."""
    global _rwkv_tokenizer, _model_registry, _default_model_a_label, _default_model_b_label, _stats_manager
    print("Initializing models...")
    load_precomputed_example()

    specs = _build_candidate_specs()
    default_a, default_b = _pick_default_pair(specs)

    print("Loading shared RWKV tokenizer...")
    _rwkv_tokenizer = _load_rwkv_tokenizer()

    # One registry entry per spec; model weights are loaded eagerly here so
    # the first user request doesn't pay the load cost.
    _model_registry = {}
    for spec in specs:
        print(f"Loading {spec['display_name']} from {spec['filename']}...")
        entry = {field: spec[field] for field in ("filename", "path", "display_name", "size_b", "size_key")}
        entry["model"] = _load_rwkv_model(spec["path"])
        _model_registry[spec["label"]] = entry

    _default_model_a_label = default_a["label"]
    _default_model_b_label = default_b["label"]

    from core.inference_stats import InferenceStatsManager
    _stats_manager = InferenceStatsManager()

    print(f"Default Model A: {_default_model_a_label}")
    print(f"Default Model B: {_default_model_b_label}")
    print("All required models loaded successfully!")
def get_model_dropdown_choices():
    """Return (choices, default_a, default_b) for the model dropdowns.

    Uses the loaded registry when available. Before initialize_models() has
    run, synthesizes placeholder labels from the required size list so the
    UI can still render.
    """
    if _model_registry:
        choices = list(_model_registry.keys())
        value_a = _default_model_a_label or (choices[0] if choices else None)
        value_b = _default_model_b_label or (choices[1] if len(choices) > 1 else value_a)
        return choices, value_a, value_b

    # Fallback before models load: one synthetic label per required size.
    labels_by_size = {}
    for size_key in REQUIRED_MODEL_SIZES:
        preferred = PREFERRED_MODEL_FILENAMES.get(size_key)
        fname = preferred if preferred else f"rwkv7-g1d-{size_key}-*.pth"
        labels_by_size[size_key] = f"RWKV7-G1D-{size_key.upper()} ({fname})"
    choices = list(labels_by_size.values())

    # Bug fix: defaults previously came from fixed positions (choices[1] and
    # choices[2]) regardless of configuration; honor the configured default
    # sizes so the fallback matches the post-load selection.
    value_a = labels_by_size.get(DEFAULT_MODEL_A_SIZE) or (choices[0] if choices else None)
    value_b = labels_by_size.get(DEFAULT_MODEL_B_SIZE) or (choices[1] if len(choices) > 1 else value_a)
    if value_b == value_a and len(choices) > 1:
        # Keep A and B distinct whenever more than one choice exists.
        value_b = next(c for c in choices if c != value_a)
    return choices, value_a, value_b
def wrap_html_in_iframe(html: str) -> str:
    """Wrap HTML in an iframe for Gradio display.

    The document is embedded through the srcdoc attribute, so it must be
    attribute-escaped: '&' first (bug fix — previously entities already in
    the HTML, e.g. '&lt;', were decoded a second time by the attribute
    parser), then '"'.
    """
    escaped = html.replace("&", "&amp;").replace('"', "&quot;")
    # Grow the iframe to fit its content after load; retried twice because
    # layout can settle after the initial onload.
    onload_js = (
        "(function(f){"
        "function r(){try{var d=f.contentWindow.document;"
        "if(!d)return;var h=Math.max(d.body.scrollHeight,d.documentElement.scrollHeight);"
        "f.style.height=(h+2)+'px';}catch(e){}}"
        "r();setTimeout(r,50);setTimeout(r,200);"
        "})(this)"
    )
    return f"""
<div style="width:100%;border:1px solid #ddd;border-radius:8px;overflow:hidden;">
    <iframe srcdoc="{escaped}"
            style="width:100%;border:none;height:400px;"
            sandbox="allow-scripts allow-same-origin"
            onload="{onload_js}"></iframe>
</div>
"""
def run_evaluation(text: str, model_a_label: str, model_b_label: str, progress=gr.Progress()):
    """Run evaluation on selected RWKV Model A and Model B and generate visualization.

    Args:
        text: Raw user input; validated and NFC-normalized via validate_input().
        model_a_label / model_b_label: Distinct keys into _model_registry.
        progress: Gradio progress reporter (injected by the UI binding).

    Returns:
        Iframe-wrapped comparison HTML (see wrap_html_in_iframe).

    Raises:
        gr.Error: On invalid selection/input, GPU OOM, or any evaluation failure.
    """
    # Heavy project imports are deferred until a run is actually requested.
    from core.evaluator import evaluate_rwkv7_single_sample
    from visualization.html_generator import generate_comparison_html
    global _rwkv_tokenizer, _model_registry, _stats_manager
    # --- selection validation (before any model work) ---
    if not model_a_label or model_a_label not in _model_registry:
        raise gr.Error("Please choose a valid Model A.")
    if not model_b_label or model_b_label not in _model_registry:
        raise gr.Error("Please choose a valid Model B.")
    if model_a_label == model_b_label:
        raise gr.Error("Model A and Model B must be different.")
    # --- input validation; on success `result` is the normalized text ---
    valid, result = validate_input(text)
    if not valid:
        raise gr.Error(result)
    text = result
    model_a_entry = _model_registry[model_a_label]
    model_b_entry = _model_registry[model_b_label]
    try:
        # Token count drives the per-model runtime estimate in the progress text.
        tokenized = _rwkv_tokenizer.encode(text)
        # The tokenizer may return an object exposing .ids or a plain id list.
        token_count = len(tokenized.ids if hasattr(tokenized, "ids") else tokenized)
        model_a_stats_key = f"rwkv::{model_a_entry['filename']}"
        model_b_stats_key = f"rwkv::{model_b_entry['filename']}"
        model_a_predicted_time = _stats_manager.predict_time(model_a_stats_key, token_count)
        model_b_predicted_time = _stats_manager.predict_time(model_b_stats_key, token_count)
        # --- Model A pass ---
        if model_a_predicted_time is not None:
            progress(0, desc=f"Evaluating Model A {model_a_entry['display_name']}... (estimated: {model_a_predicted_time:.1f}s)")
        else:
            progress(0, desc=f"Evaluating Model A {model_a_entry['display_name']}...")
        result_a = evaluate_rwkv7_single_sample(model_a_entry["model"], _rwkv_tokenizer, text)
        # Record actual runtime so future estimates improve.
        _stats_manager.add_record(model_a_stats_key, token_count, result_a["inference_time"])
        # --- Model B pass ---
        if model_b_predicted_time is not None:
            progress(0, desc=f"Evaluating Model B {model_b_entry['display_name']}... (estimated: {model_b_predicted_time:.1f}s)")
        else:
            progress(0, desc=f"Evaluating Model B {model_b_entry['display_name']}...")
        result_b = evaluate_rwkv7_single_sample(model_b_entry["model"], _rwkv_tokenizer, text)
        _stats_manager.add_record(model_b_stats_key, token_count, result_b["inference_time"])
        # --- render the side-by-side visualization ---
        progress(0, desc="Generating visualization...")
        html = generate_comparison_html(
            text=text,
            byte_losses_a=result_a["byte_wise_losses"],
            byte_losses_b=result_b["byte_wise_losses"],
            model_a_name=model_a_entry["display_name"],
            model_b_name=model_b_entry["display_name"],
            topk_predictions_a=result_a["top5_predictions"],
            topk_predictions_b=result_b["top5_predictions"],
            tokenizer_a=_rwkv_tokenizer,
            tokenizer_b=_rwkv_tokenizer,
            model_type_a="rwkv7",
            model_type_b="rwkv7",
            default_delta_mode="absolute",
        )
        return wrap_html_in_iframe(html)
    except torch.cuda.OutOfMemoryError:
        # Best-effort GPU/heap cleanup before surfacing a user-facing error.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        raise gr.Error("GPU memory insufficient. Please try:\n1. Use shorter text\n2. Wait a moment and try again")
    except Exception as e:
        # UI boundary: convert any failure into gr.Error after cleanup.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        raise gr.Error(f"Evaluation failed: {str(e)}")
def clear_inputs():
    """Reset the text input and the visualization output."""
    return ("", None)
def get_default_example():
    """Get default example text/html and dropdown updates for page load."""
    choices, value_a, value_b = get_model_dropdown_choices()
    update_a = gr.update(choices=choices, value=value_a)
    update_b = gr.update(choices=choices, value=value_b)
    # Show the precomputed example only when both caches were populated.
    if _precomputed_html and _precomputed_text:
        return _precomputed_text, wrap_html_in_iframe(_precomputed_html), update_a, update_b
    return "", None, update_a, update_b
# Prepare model dropdown choices for UI construction (at import time the
# registry is empty, so these come from the fallback label branch).
_model_choices_for_ui, _default_a_for_ui, _default_b_for_ui = get_model_dropdown_choices()

# Build Gradio UI
with gr.Blocks(
    title="RWKV-ScaleLens",
    theme=gr.themes.Soft(),
    # Custom CSS: monospace input textarea, and un-clamp accordion content
    # height so the LaTeX explainer renders without internal scrollbars.
    css="""
    #input-text textarea {
        font-family: Consolas, 'Courier New', monospace;
    }
    .gr-accordion-content {
        max-height: none !important;
        height: auto !important;
        overflow: visible !important;
    }
    .gr-accordion-content > div {
        max-height: none !important;
        height: auto !important;
        overflow: visible !important;
    }
    .gr-accordion-content .prose,
    .gr-accordion-content .markdown,
    .gr-accordion-content .md {
        max-height: none !important;
        height: auto !important;
        overflow: visible !important;
    }
    #compression-metric .gr-accordion-content,
    #compression-metric .gr-accordion-content > div,
    #compression-metric .prose,
    #compression-metric .markdown,
    #compression-metric .md,
    #compression-metric * {
        max-height: none !important;
        overflow: visible !important;
        overflow-y: visible !important;
        overflow-x: visible !important;
    }
    """,
) as demo:
    # Page header
    gr.HTML(
        """
        <div style="text-align: center; margin-bottom: 20px;">
            <h1 style="margin-bottom: 10px;">RWKV-ScaleLens</h1>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Model pickers; choices are refreshed with real registry labels
            # on demo.load via get_default_example().
            with gr.Row():
                model_a_selector = gr.Dropdown(
                    label="Model A",
                    choices=_model_choices_for_ui,
                    value=_default_a_for_ui,
                    interactive=True,
                )
                model_b_selector = gr.Dropdown(
                    label="Model B",
                    choices=_model_choices_for_ui,
                    value=_default_b_for_ui,
                    interactive=True,
                )
            text_input = gr.Textbox(
                label="Input Text",
                placeholder=f"Enter text to analyze (max {MAX_TEXT_LENGTH} characters)...",
                lines=10,
                max_lines=20,
                elem_id="input-text",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary")
                run_btn = gr.Button("Run Comparison", variant="primary")
    gr.Markdown("---")
    with gr.Row():
        with gr.Column():
            output_html = gr.HTML(label="Visualization")
            # Collapsible explainer for the compression-rate metric.
            with gr.Accordion("How to calculate compression rate?", open=False, elem_id="compression-metric"):
                gr.Markdown(
                    r"""
                    The compression rate $R(t)$ represents the ratio of the compressed bitstream length to the original data size. It is derived from the model's negative log-likelihood loss $\mathcal{L}_{\text{NLL}}(t) = -\ln P(t)$:
                    $$
                    R(t) = \frac{\mathcal{L}_{\text{NLL}}(t)}{\ln 2 \cdot 8 \cdot L(t)} \times 100\%
                    $$
                    where $L(t)$ is the token length in bytes, and the factor $(\ln 2 \cdot 8)^{-1}$ normalizes the loss from nats to percentage of the original data size.

                    **Example.** For a 1-byte token ($L=1$) with probability $P(t) = 0.5$:
                    $$
                    R(t) = \frac{-\ln(0.5)}{\ln 2 \cdot 8 \cdot 1} \times 100\% = 12.5\%
                    $$
                    """,
                    latex_delimiters=[
                        {"left": "$$", "right": "$$", "display": True},
                        {"left": "$", "right": "$", "display": False},
                    ],
                )
    # Event wiring
    clear_btn.click(fn=clear_inputs, outputs=[text_input, output_html])
    run_btn.click(fn=run_evaluation, inputs=[text_input, model_a_selector, model_b_selector], outputs=[output_html])
    # On page load: fill in the precomputed example and refresh both dropdowns.
    demo.load(fn=get_default_example, outputs=[text_input, output_html, model_a_selector, model_b_selector])

if __name__ == "__main__":
    # Load all models (and the precomputed example) before serving the UI.
    initialize_models()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)