# RWKV-ScaleLens / app.py
# Author: Jellyfish042
# chore: swap default model A/B and precompute ordering (commit ffc40f8)
"""
UncheatableEval Visualization - RWKV Model A vs Model B
Compare byte-level prediction performance between two selectable RWKV models.
Required candidate sizes are 0.1B / 0.4B / 1.5B.
Models are loaded from local project directory first, and auto-downloaded when missing.
"""
import gc
import os
from pathlib import Path
import re
import unicodedata
import gradio as gr
import torch
# Detect device
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = "cpu"  # forced to CPU; GPU detection kept above for reference
IS_CPU = DEVICE == "cpu"
# Model configuration
HF_REPO_ID = "BlinkDL/rwkv7-g1"  # Hugging Face repo checkpoints are pulled from
REQUIRED_MODEL_SIZES = ["0.1b", "0.4b", "1.5b"]  # TEMP: 2.9b disabled due to OOM
# Pinned checkpoint filename per size; checked before glob / remote search.
PREFERRED_MODEL_FILENAMES = {
    "0.1b": "rwkv7-g1d-0.1b-20260129-ctx8192.pth",
    "0.4b": "rwkv7-g1d-0.4b-20260210-ctx8192.pth",
    "1.5b": "rwkv7-g1d-1.5b-20260212-ctx8192.pth",
    # "2.9b": "rwkv7-g1d-2.9b-20260131-ctx8192.pth", # TEMP: disabled due to OOM
}
# Default dropdown selections, expressed as size keys and resolved to
# registry labels during initialize_models().
DEFAULT_MODEL_A_SIZE = "1.5b"
DEFAULT_MODEL_B_SIZE = "0.4b"
# Get the directory where this script is located
SCRIPT_DIR = Path(__file__).parent.absolute()
MODELS_DIR = SCRIPT_DIR / "models"  # local/downloaded checkpoints
SUPPORT_DIR = SCRIPT_DIR / "support"  # tokenizer vocab file lives here
# Text length limits (characters, applied after NFC normalization)
MAX_TEXT_LENGTH = 16384
MIN_TEXT_LENGTH = 1
# Global model cache — populated once by initialize_models()
_rwkv_tokenizer = None
_model_registry = {}  # label -> {filename, path, display_name, size_b, model, size_key}
_default_model_a_label = None
_default_model_b_label = None
_stats_manager = None
# Precomputed example cache (filled by load_precomputed_example)
_precomputed_html = None
_precomputed_text = None
PRECOMPUTED_DIR = SCRIPT_DIR / "precomputed"
def _parse_size_b(filename: str):
match = re.search(r"-(\d+(?:\.\d+)?)b-", filename.lower())
if not match:
return None
try:
return float(match.group(1))
except ValueError:
return None
def _display_name_from_filename(filename: str) -> str:
    """Build a human-readable label such as "RWKV7-G1D-0.4B" from a filename."""
    lowered = filename.lower()
    if "g1d" in lowered:
        family = "RWKV7-G1D"
    elif "g1c" in lowered:
        family = "RWKV7-G1C"
    else:
        family = "RWKV7"
    size_b = _parse_size_b(filename)
    size_text = "Unknown" if size_b is None else f"{size_b:.1f}B"
    return f"{family}-{size_text}"
def _size_to_pattern(size_key: str) -> str:
return f"rwkv7-g1d-{size_key}-*.pth"
def _extract_date_token(filename: str):
m = re.search(r"-(\d{8})-", filename)
return m.group(1) if m else "00000000"
def _pick_best_filename(filenames):
if not filenames:
return None
return sorted(filenames, key=lambda x: (_extract_date_token(x), x))[-1]
def _find_local_filename_for_size(size_key: str):
    """Return the best-dated checkpoint filename under MODELS_DIR, or None.

    Creates MODELS_DIR on first use so the glob never fails on a fresh checkout.
    """
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    candidates = [entry.name for entry in MODELS_DIR.glob(_size_to_pattern(size_key))]
    return _pick_best_filename(candidates)
def _list_repo_files():
    """Fetch the full file listing of the configured Hugging Face model repo."""
    from huggingface_hub import HfApi
    return HfApi().list_repo_files(repo_id=HF_REPO_ID, repo_type="model")
def _find_remote_filename_for_size(size_key: str, repo_files):
    """Pick the newest remote checkpoint filename for a size key, or None."""
    name_re = re.compile(rf"^rwkv7-g1d-{re.escape(size_key)}-.*\.pth$", re.IGNORECASE)
    candidates = [name for name in repo_files if name_re.match(name)]
    return _pick_best_filename(candidates)
def _ensure_model_file(size_key: str, repo_files_cache=None) -> str:
    """Ensure one specific size model exists in local models directory.
    Returns local absolute path.

    Resolution order:
      1. pinned filename from PREFERRED_MODEL_FILENAMES already on disk;
      2. any local file matching the size glob (newest by date token);
      3. download from HF_REPO_ID (pinned name if present in the repo,
         otherwise the newest matching remote file).

    Raises RuntimeError when no matching remote file exists.
    """
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    preferred = PREFERRED_MODEL_FILENAMES.get(size_key)
    if preferred:
        preferred_path = MODELS_DIR / preferred
        if preferred_path.exists():
            return str(preferred_path)
    local_filename = _find_local_filename_for_size(size_key)
    if local_filename:
        return str(MODELS_DIR / local_filename)
    # Nothing local — consult the remote listing (fetched lazily so callers
    # with a warm cache avoid a network round-trip).
    if repo_files_cache is None:
        repo_files_cache = _list_repo_files()
    remote_filename = preferred
    if remote_filename is None or remote_filename not in repo_files_cache:
        remote_filename = _find_remote_filename_for_size(size_key, repo_files_cache)
    if not remote_filename:
        raise RuntimeError(
            f"Could not find remote RWKV file for size {size_key} in repo {HF_REPO_ID}."
        )
    from huggingface_hub import hf_hub_download
    print(f"Downloading missing model {remote_filename} from {HF_REPO_ID} ...")
    local_path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=remote_filename,
        local_dir=str(MODELS_DIR),
        # NOTE(review): local_dir_use_symlinks is deprecated/ignored in recent
        # huggingface_hub releases — confirm the pinned version still accepts it.
        local_dir_use_symlinks=False,
    )
    return str(Path(local_path).resolve())
def _build_candidate_specs():
    """Build required model specs from fixed size list, auto-downloading missing files.

    Returns a list of dicts with keys label, filename, path, display_name,
    size_b, size_key — one per entry in REQUIRED_MODEL_SIZES.
    """
    repo_files_cache = None
    specs = []
    for size_key in REQUIRED_MODEL_SIZES:
        try:
            model_path = _ensure_model_file(size_key, repo_files_cache=repo_files_cache)
        except Exception:
            # First attempt may have run against a missing/stale remote listing;
            # refresh the repo file list once and retry before giving up.
            if repo_files_cache is None:
                repo_files_cache = _list_repo_files()
            model_path = _ensure_model_file(size_key, repo_files_cache=repo_files_cache)
        p = Path(model_path)
        filename = p.name
        size_b = _parse_size_b(filename)
        display_name = _display_name_from_filename(filename)
        # BUGFIX: the label previously contained the literal placeholder
        # "((unknown))" instead of the checkpoint filename. Registry keys and
        # dropdown entries must be unique and match the fallback label format
        # "DISPLAY_NAME (filename)" used in get_model_dropdown_choices().
        label = f"{display_name} ({filename})"
        specs.append(
            {
                "label": label,
                "filename": filename,
                "path": str(p),
                "display_name": display_name,
                "size_b": size_b,
                "size_key": size_key,
            }
        )
    return specs
def _pick_default_pair(specs):
    """Choose the default (Model A, Model B) spec dicts from the candidates.

    Prefers DEFAULT_MODEL_A_SIZE / DEFAULT_MODEL_B_SIZE. When A is missing,
    falls back to the smallest available model; when B is missing or equals A,
    picks any other model (or A itself if it is the only one).
    Returns (None, None) for an empty candidate list.
    """
    if not specs:
        return None, None
    by_size = {spec["size_key"]: spec for spec in specs}
    chosen_a = by_size.get(DEFAULT_MODEL_A_SIZE)
    chosen_b = by_size.get(DEFAULT_MODEL_B_SIZE)
    if chosen_a is None:
        # Unknown sizes sort last; ties break on filename. min() on this key is
        # equivalent to taking the first element of a stable sort.
        def _size_key(spec):
            size = spec["size_b"]
            return (size is None, size if size is not None else 1e9, spec["filename"])
        chosen_a = min(specs, key=_size_key)
    if chosen_b is None or chosen_b["filename"] == chosen_a["filename"]:
        others = [spec for spec in specs if spec["filename"] != chosen_a["filename"]]
        chosen_b = others[0] if others else chosen_a
    return chosen_a, chosen_b
def _load_rwkv_model(model_path: str):
    """Load a RWKV7 model from a local checkpoint path."""
    # The rwkv package reads these environment variables when constructing
    # the model, so they must be set before the import below.
    os.environ["RWKV_JIT_ON"] = "1"
    os.environ["RWKV_V7_ON"] = "1"
    os.environ["RWKV_CUDA_ON"] = "0" if IS_CPU else "1"
    from rwkv.model import RWKV
    # RWKV expects the checkpoint path without its ".pth" suffix.
    base_path = model_path[:-4] if model_path.endswith(".pth") else model_path
    strategy = "cpu fp32" if IS_CPU else "cuda fp16"
    return RWKV(model=base_path, strategy=strategy)
def _load_rwkv_tokenizer():
    """Load the shared RWKV trie tokenizer from the support directory."""
    from rwkv.rwkv_tokenizer import TRIE_TOKENIZER
    return TRIE_TOKENIZER(str(SUPPORT_DIR / "rwkv_vocab_v20230424.txt"))
def validate_input(text: str) -> tuple[bool, str]:
    """Validate and normalize user input text.

    Returns (True, normalized_text) on success, where the text is
    NFC-normalized and stripped, or (False, error_message) on failure.
    """
    if not text or not text.strip():
        return False, "Please enter some text to analyze."
    text = unicodedata.normalize("NFC", text).strip()
    length = len(text)
    if length < MIN_TEXT_LENGTH:
        return False, f"Text is too short. Minimum {MIN_TEXT_LENGTH} characters required."
    if length > MAX_TEXT_LENGTH:
        return False, f"Text is too long. Maximum {MAX_TEXT_LENGTH} characters allowed. Current: {len(text)}"
    return True, text
def load_precomputed_example():
    """Load the precomputed example visualization into the module caches.

    Populates _precomputed_html / _precomputed_text from PRECOMPUTED_DIR.
    Returns True when both files were found, False otherwise.
    """
    global _precomputed_html, _precomputed_text
    html_path = PRECOMPUTED_DIR / "example_visualization.html"
    metadata_path = PRECOMPUTED_DIR / "example_metadata.json"
    if not (html_path.exists() and metadata_path.exists()):
        print("No precomputed example found. Run precompute_example.py first.")
        return False
    import json
    _precomputed_html = html_path.read_text(encoding="utf-8")
    metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
    _precomputed_text = metadata.get("example_text", "")
    print(f"Loaded precomputed example ({len(_precomputed_text)} chars)")
    return True
def initialize_models():
    """Initialize and cache all required RWKV models at startup.

    Side effects: populates _rwkv_tokenizer, _model_registry,
    _default_model_a_label, _default_model_b_label and _stats_manager.
    Downloads any missing checkpoints via _build_candidate_specs().
    """
    global _rwkv_tokenizer, _model_registry, _default_model_a_label, _default_model_b_label, _stats_manager
    print("Initializing models...")
    load_precomputed_example()
    specs = _build_candidate_specs()
    # Resolve default dropdown selections before loading weights so a failure
    # here aborts early.
    default_a, default_b = _pick_default_pair(specs)
    print("Loading shared RWKV tokenizer...")
    _rwkv_tokenizer = _load_rwkv_tokenizer()
    _model_registry = {}
    for spec in specs:
        print(f"Loading {spec['display_name']} from {spec['filename']}...")
        model = _load_rwkv_model(spec["path"])
        # Registry is keyed by the dropdown label.
        _model_registry[spec["label"]] = {
            "filename": spec["filename"],
            "path": spec["path"],
            "display_name": spec["display_name"],
            "size_b": spec["size_b"],
            "size_key": spec["size_key"],
            "model": model,
        }
    _default_model_a_label = default_a["label"]
    _default_model_b_label = default_b["label"]
    # Imported lazily to keep module import light.
    from core.inference_stats import InferenceStatsManager
    _stats_manager = InferenceStatsManager()
    print(f"Default Model A: {_default_model_a_label}")
    print(f"Default Model B: {_default_model_b_label}")
    print("All required models loaded successfully!")
def get_model_dropdown_choices():
    """Return (choices, default_a, default_b) for the two model dropdowns.

    Uses the loaded model registry when available. Before initialization it
    synthesizes placeholder labels from the required size list so the UI can
    be constructed; those are replaced on demo.load.
    """
    if _model_registry:
        choices = list(_model_registry.keys())
        value_a = _default_model_a_label or (choices[0] if choices else None)
        value_b = _default_model_b_label or (choices[1] if len(choices) > 1 else value_a)
        return choices, value_a, value_b
    # Registry not populated yet — build fallback labels.
    choices = []
    for size_key in REQUIRED_MODEL_SIZES:
        preferred = PREFERRED_MODEL_FILENAMES.get(size_key)
        fname = preferred if preferred else f"rwkv7-g1d-{size_key}-*.pth"
        choices.append(f"RWKV7-G1D-{size_key.upper()} ({fname})")
    value_a = choices[1] if len(choices) > 1 else (choices[0] if choices else None)
    value_b = choices[2] if len(choices) > 2 else (choices[0] if choices else None)
    return choices, value_a, value_b
def wrap_html_in_iframe(html: str) -> str:
    """Wrap HTML in a sandboxed, self-resizing iframe for Gradio display.

    The document goes into the srcdoc attribute, so double quotes are
    entity-escaped. The onload handler resizes the frame to its content,
    retrying twice to catch late layout changes.
    """
    escaped = html.replace('"', "&quot;")
    onload_js = (
        "(function(f){"
        "function r(){try{var d=f.contentWindow.document;"
        "if(!d)return;var h=Math.max(d.body.scrollHeight,d.documentElement.scrollHeight);"
        "f.style.height=(h+2)+'px';}catch(e){}}"
        "r();setTimeout(r,50);setTimeout(r,200);"
        "})(this)"
    )
    return f"""
<div style="width:100%;border:1px solid #ddd;border-radius:8px;overflow:hidden;">
    <iframe srcdoc="{escaped}"
            style="width:100%;border:none;height:400px;"
            sandbox="allow-scripts allow-same-origin"
            onload="{onload_js}"></iframe>
</div>
"""
def run_evaluation(text: str, model_a_label: str, model_b_label: str, progress=gr.Progress()):
    """Run evaluation on selected RWKV Model A and Model B and generate visualization.

    Returns the comparison visualization wrapped in an iframe (HTML string).
    Raises gr.Error for invalid selections/input and for evaluation failures.
    """
    from core.evaluator import evaluate_rwkv7_single_sample
    from visualization.html_generator import generate_comparison_html
    global _rwkv_tokenizer, _model_registry, _stats_manager
    # Reject bad selections before any expensive work.
    if not model_a_label or model_a_label not in _model_registry:
        raise gr.Error("Please choose a valid Model A.")
    if not model_b_label or model_b_label not in _model_registry:
        raise gr.Error("Please choose a valid Model B.")
    if model_a_label == model_b_label:
        raise gr.Error("Model A and Model B must be different.")
    valid, result = validate_input(text)
    if not valid:
        raise gr.Error(result)
    text = result  # NFC-normalized, stripped text from validate_input
    model_a_entry = _model_registry[model_a_label]
    model_b_entry = _model_registry[model_b_label]
    try:
        # Token count feeds the per-model runtime estimate in the progress bar.
        tokenized = _rwkv_tokenizer.encode(text)
        token_count = len(tokenized.ids if hasattr(tokenized, "ids") else tokenized)
        model_a_stats_key = f"rwkv::{model_a_entry['filename']}"
        model_b_stats_key = f"rwkv::{model_b_entry['filename']}"
        model_a_predicted_time = _stats_manager.predict_time(model_a_stats_key, token_count)
        model_b_predicted_time = _stats_manager.predict_time(model_b_stats_key, token_count)
        if model_a_predicted_time is not None:
            progress(0, desc=f"Evaluating Model A {model_a_entry['display_name']}... (estimated: {model_a_predicted_time:.1f}s)")
        else:
            progress(0, desc=f"Evaluating Model A {model_a_entry['display_name']}...")
        result_a = evaluate_rwkv7_single_sample(model_a_entry["model"], _rwkv_tokenizer, text)
        # Record the actual runtime to refine future estimates.
        _stats_manager.add_record(model_a_stats_key, token_count, result_a["inference_time"])
        if model_b_predicted_time is not None:
            progress(0, desc=f"Evaluating Model B {model_b_entry['display_name']}... (estimated: {model_b_predicted_time:.1f}s)")
        else:
            progress(0, desc=f"Evaluating Model B {model_b_entry['display_name']}...")
        result_b = evaluate_rwkv7_single_sample(model_b_entry["model"], _rwkv_tokenizer, text)
        _stats_manager.add_record(model_b_stats_key, token_count, result_b["inference_time"])
        progress(0, desc="Generating visualization...")
        html = generate_comparison_html(
            text=text,
            byte_losses_a=result_a["byte_wise_losses"],
            byte_losses_b=result_b["byte_wise_losses"],
            model_a_name=model_a_entry["display_name"],
            model_b_name=model_b_entry["display_name"],
            topk_predictions_a=result_a["top5_predictions"],
            topk_predictions_b=result_b["top5_predictions"],
            tokenizer_a=_rwkv_tokenizer,
            tokenizer_b=_rwkv_tokenizer,
            model_type_a="rwkv7",
            model_type_b="rwkv7",
            default_delta_mode="absolute",
        )
        return wrap_html_in_iframe(html)
    except torch.cuda.OutOfMemoryError:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        raise gr.Error("GPU memory insufficient. Please try:\n1. Use shorter text\n2. Wait a moment and try again")
    except Exception as e:
        # Best-effort cleanup before surfacing the failure to the UI.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        raise gr.Error(f"Evaluation failed: {str(e)}")
def clear_inputs():
    """Reset the text input and the visualization output components."""
    return ("", None)
def get_default_example():
    """Return startup values: example text, rendered HTML, and dropdown updates.

    Falls back to empty text and no visualization when no precomputed
    example was loaded.
    """
    global _precomputed_html, _precomputed_text
    choices, value_a, value_b = get_model_dropdown_choices()
    update_a = gr.update(choices=choices, value=value_a)
    update_b = gr.update(choices=choices, value=value_b)
    if _precomputed_html and _precomputed_text:
        return _precomputed_text, wrap_html_in_iframe(_precomputed_html), update_a, update_b
    return "", None, update_a, update_b
# Prepare model dropdown choices for UI construction. At import time the
# registry is usually empty, so these are placeholder labels; the real values
# are pushed in on demo.load via get_default_example().
_model_choices_for_ui, _default_a_for_ui, _default_b_for_ui = get_model_dropdown_choices()
# Build Gradio UI
# NOTE(review): block nesting reconstructed from a whitespace-mangled source —
# confirm component layout against the running app.
with gr.Blocks(
    title="RWKV-ScaleLens",
    theme=gr.themes.Soft(),
    # CSS: monospace input box, and force accordions (esp. the compression
    # metric one) to expand fully instead of scrolling internally.
    css="""
#input-text textarea {
font-family: Consolas, 'Courier New', monospace;
}
.gr-accordion-content {
max-height: none !important;
height: auto !important;
overflow: visible !important;
}
.gr-accordion-content > div {
max-height: none !important;
height: auto !important;
overflow: visible !important;
}
.gr-accordion-content .prose,
.gr-accordion-content .markdown,
.gr-accordion-content .md {
max-height: none !important;
height: auto !important;
overflow: visible !important;
}
#compression-metric .gr-accordion-content,
#compression-metric .gr-accordion-content > div,
#compression-metric .prose,
#compression-metric .markdown,
#compression-metric .md,
#compression-metric * {
max-height: none !important;
overflow: visible !important;
overflow-y: visible !important;
overflow-x: visible !important;
}
""",
) as demo:
    # Page header
    gr.HTML(
        """
<div style="text-align: center; margin-bottom: 20px;">
<h1 style="margin-bottom: 10px;">RWKV-ScaleLens</h1>
</div>
"""
    )
    # Input column: model selectors, text box, action buttons.
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Row():
                model_a_selector = gr.Dropdown(
                    label="Model A",
                    choices=_model_choices_for_ui,
                    value=_default_a_for_ui,
                    interactive=True,
                )
                model_b_selector = gr.Dropdown(
                    label="Model B",
                    choices=_model_choices_for_ui,
                    value=_default_b_for_ui,
                    interactive=True,
                )
            text_input = gr.Textbox(
                label="Input Text",
                placeholder=f"Enter text to analyze (max {MAX_TEXT_LENGTH} characters)...",
                lines=10,
                max_lines=20,
                elem_id="input-text",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary")
                run_btn = gr.Button("Run Comparison", variant="primary")
    gr.Markdown("---")
    # Output area: visualization plus an explanatory accordion.
    with gr.Row():
        with gr.Column():
            output_html = gr.HTML(label="Visualization")
            with gr.Accordion("How to calculate compression rate?", open=False, elem_id="compression-metric"):
                gr.Markdown(
                    r"""
The compression rate $R(t)$ represents the ratio of the compressed bitstream length to the original data size. It is derived from the model's negative log-likelihood loss $\mathcal{L}_{\text{NLL}}(t) = -\ln P(t)$:
$$
R(t) = \frac{\mathcal{L}_{\text{NLL}}(t)}{\ln 2 \cdot 8 \cdot L(t)} \times 100\%
$$
where $L(t)$ is the token length in bytes, and the factor $(\ln 2 \cdot 8)^{-1}$ normalizes the loss from nats to percentage of the original data size.
**Example.** For a 1-byte token ($L=1$) with probability $P(t) = 0.5$:
$$
R(t) = \frac{-\ln(0.5)}{\ln 2 \cdot 8 \cdot 1} \times 100\% = 12.5\%
$$
""",
                    latex_delimiters=[
                        {"left": "$$", "right": "$$", "display": True},
                        {"left": "$", "right": "$", "display": False},
                    ],
                )
    # Event wiring
    clear_btn.click(fn=clear_inputs, outputs=[text_input, output_html])
    run_btn.click(fn=run_evaluation, inputs=[text_input, model_a_selector, model_b_selector], outputs=[output_html])
    demo.load(fn=get_default_example, outputs=[text_input, output_html, model_a_selector, model_b_selector])
if __name__ == "__main__":
    # Load all checkpoints and the tokenizer before serving, so the first
    # request does not pay the model-loading cost.
    initialize_models()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)