# Hugging Face Space app by Jellyfish042 (commit 68b02f7)
"""
UncheatableEval Visualization - Hugging Face Space
Compare byte-level prediction performance between Qwen3-1.7B-Base and RWKV7-G1C-1.5B.
"""
import gc
import os
from pathlib import Path
import gradio as gr
import torch
# Detect device: prefer CUDA when available; everything below adapts to this.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IS_CPU = DEVICE == "cpu"

# Model configuration
QWEN_MODEL_ID = "Qwen/Qwen3-1.7B-Base"
# Direct URL kept for reference; the actual download goes through hf_hub_download.
RWKV_MODEL_URL = "https://huggingface.co/BlinkDL/rwkv7-g1/resolve/main/rwkv7-g1c-1.5b-20260110-ctx8192.pth"
RWKV_MODEL_FILENAME = "rwkv7-g1c-1.5b-20260110-ctx8192.pth"

# Get the directory where this script is located
SCRIPT_DIR = Path(__file__).parent.absolute()
MODELS_DIR = SCRIPT_DIR / "models"  # downloaded RWKV checkpoint lives here
SUPPORT_DIR = SCRIPT_DIR / "support"  # RWKV tokenizer vocab file lives here

# Text length limits (in characters), enforced by validate_input()
MAX_TEXT_LENGTH = 8192
MIN_TEXT_LENGTH = 1

# Global model cache — populated once by initialize_models() at startup
_qwen_model = None
_qwen_tokenizer = None
_rwkv_model = None
_rwkv_tokenizer = None
_rwkv_model_path = None
_stats_manager = None

# Precomputed example cache — filled by load_precomputed_example()
_precomputed_html = None
_precomputed_text = None
PRECOMPUTED_DIR = SCRIPT_DIR / "precomputed"
def download_rwkv_model(progress=None):
    """Download the RWKV7 checkpoint into MODELS_DIR if it is not already there.

    Args:
        progress: Unused; kept for interface compatibility with callers that
            pass a Gradio progress tracker.

    Returns:
        str: Filesystem path to the (cached or freshly downloaded) .pth file.
    """
    from huggingface_hub import hf_hub_download

    model_path = MODELS_DIR / RWKV_MODEL_FILENAME
    if model_path.exists():
        return str(model_path)
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    # Download from the HuggingFace Hub directly into MODELS_DIR.
    # NOTE: local_dir_use_symlinks is deprecated and ignored by recent
    # huggingface_hub releases (files are always real copies with local_dir),
    # so it is intentionally no longer passed.
    downloaded_path = hf_hub_download(
        repo_id="BlinkDL/rwkv7-g1",
        filename=RWKV_MODEL_FILENAME,
        local_dir=str(MODELS_DIR),
    )
    return downloaded_path
def load_qwen_model():
    """Load Qwen3-1.7B-Base and its tokenizer, configured for the active device.

    Returns:
        tuple: (model in eval mode, tokenizer).
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID, trust_remote_code=True)
    if IS_CPU:
        # CPU path: full-precision weights, no device map, low peak RAM while loading.
        load_kwargs = {
            "torch_dtype": torch.float32,
            "device_map": None,
            "trust_remote_code": True,
            "low_cpu_mem_usage": True,
        }
        model = AutoModelForCausalLM.from_pretrained(QWEN_MODEL_ID, **load_kwargs)
    else:
        # GPU path: bfloat16 + automatic placement; try FlashAttention-2 first.
        load_kwargs = {
            "torch_dtype": torch.bfloat16,
            "device_map": "auto",
            "trust_remote_code": True,
        }
        try:
            model = AutoModelForCausalLM.from_pretrained(
                QWEN_MODEL_ID, attn_implementation="flash_attention_2", **load_kwargs
            )
        except Exception:
            # FlashAttention-2 not installed or unsupported on this GPU: fall back
            # to the default attention implementation.
            model = AutoModelForCausalLM.from_pretrained(QWEN_MODEL_ID, **load_kwargs)
    return model.eval(), tokenizer
def load_rwkv7_model(model_path: str):
    """Load RWKV7-G1C-1.5B model.

    Args:
        model_path: Path to the .pth checkpoint (with or without extension).

    Returns:
        tuple: (RWKV model, TRIE_TOKENIZER instance).

    NOTE: the RWKV_* environment variables are set BEFORE importing rwkv.model
    on purpose — the library presumably reads them at import time, so this
    ordering must be preserved.
    """
    os.environ["RWKV_JIT_ON"] = "1"
    os.environ["RWKV_V7_ON"] = "1"
    # Set CUDA flag based on device
    if IS_CPU:
        os.environ["RWKV_CUDA_ON"] = "0"
    else:
        os.environ["RWKV_CUDA_ON"] = "1"
    from rwkv.model import RWKV
    from rwkv.rwkv_tokenizer import TRIE_TOKENIZER
    # Use appropriate strategy for device
    if IS_CPU:
        strategy = "cpu fp32"
    else:
        strategy = "cuda fp16"
    # RWKV library automatically adds .pth extension, so remove it if present
    if model_path.endswith(".pth"):
        model_path = model_path[:-4]
    model = RWKV(model=model_path, strategy=strategy)
    # Byte-level trie tokenizer with the standard RWKV world vocab.
    vocab_path = str(SUPPORT_DIR / "rwkv_vocab_v20230424.txt")
    tokenizer = TRIE_TOKENIZER(vocab_path)
    return model, tokenizer
def validate_input(text: str) -> tuple[bool, str]:
    """Validate the user-supplied text.

    Returns:
        (True, cleaned_text) when the input is acceptable, otherwise
        (False, human-readable error message).
    """
    if not text or not text.strip():
        return False, "Please enter some text to analyze."
    cleaned = text.strip()
    length = len(cleaned)
    if length < MIN_TEXT_LENGTH:
        return False, f"Text is too short. Minimum {MIN_TEXT_LENGTH} characters required."
    if length > MAX_TEXT_LENGTH:
        return False, f"Text is too long. Maximum {MAX_TEXT_LENGTH} characters allowed. Current: {length}"
    return True, cleaned
def load_precomputed_example():
    """Load the precomputed example visualization into the module-level cache.

    Populates _precomputed_html and _precomputed_text from files under
    PRECOMPUTED_DIR.

    Returns:
        bool: True when both precomputed files exist and were loaded.
    """
    global _precomputed_html, _precomputed_text
    import json

    html_path = PRECOMPUTED_DIR / "example_visualization.html"
    metadata_path = PRECOMPUTED_DIR / "example_metadata.json"
    if not (html_path.exists() and metadata_path.exists()):
        print("No precomputed example found. Run precompute_example.py first.")
        return False
    _precomputed_html = html_path.read_text(encoding="utf-8")
    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    _precomputed_text = metadata.get("example_text", "")
    print(f"Loaded precomputed example ({len(_precomputed_text)} chars)")
    return True
def initialize_models():
    """Initialize and cache both models at startup.

    Fills every module-level cache slot (_qwen_model, _rwkv_model, tokenizers,
    _rwkv_model_path, _stats_manager) so that run_evaluation() never has to
    load a model per request. Intended to be called once before demo.launch().
    """
    global _qwen_model, _qwen_tokenizer, _rwkv_model, _rwkv_tokenizer, _rwkv_model_path, _stats_manager
    print("Initializing models...")
    # Load precomputed example first (cheap; lets the UI show something even
    # if model loading below is slow)
    load_precomputed_example()
    # Download RWKV model if needed
    print("Checking RWKV7 model...")
    _rwkv_model_path = download_rwkv_model()
    # Load Qwen model
    print("Loading Qwen3-1.7B-Base...")
    _qwen_model, _qwen_tokenizer = load_qwen_model()
    # Load RWKV7 model
    print("Loading RWKV7-G1C-1.5B...")
    _rwkv_model, _rwkv_tokenizer = load_rwkv7_model(_rwkv_model_path)
    # Initialize stats manager (project-local; tracks per-model inference times
    # to predict durations for the progress bar)
    from core.inference_stats import InferenceStatsManager
    _stats_manager = InferenceStatsManager()
    print("Models loaded successfully!")
def wrap_html_in_iframe(html: str) -> str:
    """Embed a standalone HTML document in a sandboxed iframe for Gradio.

    Only double quotes are escaped for the srcdoc attribute; HTML entities
    already present in the document (e.g. &quot;, &#10;) must pass through
    unchanged so the browser decodes them inside the iframe.
    """
    srcdoc = html.replace('"', "&quot;")
    return f"""
    <div style="width:100%;height:700px;border:1px solid #ddd;border-radius:8px;overflow:hidden;">
        <iframe srcdoc="{srcdoc}"
                style="width:100%;height:100%;border:none;"
                sandbox="allow-scripts"></iframe>
    </div>
    """
def run_evaluation(text: str, progress=gr.Progress()):
    """Run evaluation on both cached models and return comparison HTML.

    Args:
        text: Raw user input; validated and stripped before use.
        progress: Gradio progress tracker. NOTE: gr.Progress() as a default is
            Gradio's documented injection pattern, not an ordinary mutable
            default.

    Returns:
        str: The comparison visualization wrapped in an iframe-ready div.

    Raises:
        gr.Error: On invalid input, GPU OOM, or any evaluation failure.
    """
    from core.evaluator import evaluate_hf_single_sample, evaluate_rwkv7_single_sample
    from visualization.html_generator import generate_comparison_html

    # Use cached models
    global _qwen_model, _qwen_tokenizer, _rwkv_model, _rwkv_tokenizer, _stats_manager
    # Validate input
    valid, result = validate_input(text)
    if not valid:
        raise gr.Error(result)
    text = result  # Use cleaned text
    try:
        # Get token counts for prediction first (so the progress bar can show
        # an estimated duration before each model runs)
        qwen_inputs = _qwen_tokenizer(text, return_tensors="pt", add_special_tokens=False)
        qwen_token_count = qwen_inputs["input_ids"].shape[-1]
        qwen_predicted_time = _stats_manager.predict_time("qwen", qwen_token_count)
        rwkv_tokenized = _rwkv_tokenizer.encode(text)
        # RWKV tokenizer may return a list of ids or an object with .ids
        rwkv_token_count = len(rwkv_tokenized.ids if hasattr(rwkv_tokenized, "ids") else rwkv_tokenized)
        rwkv_predicted_time = _stats_manager.predict_time("rwkv", rwkv_token_count)
        # Step 1: Evaluate Qwen (using cached model)
        if qwen_predicted_time is not None:
            progress(0, desc=f"Evaluating with Qwen3... (estimated: {qwen_predicted_time:.1f}s)")
        else:
            progress(0, desc="Evaluating with Qwen3...")
        result_qwen = evaluate_hf_single_sample(_qwen_model, _qwen_tokenizer, text, bos_mode="add_newline_token")
        # Save stats and print comparison
        _stats_manager.add_record("qwen", qwen_token_count, result_qwen["inference_time"])
        if qwen_predicted_time is not None:
            print(f"Qwen3 completed in {result_qwen['inference_time']:.2f}s (predicted: {qwen_predicted_time:.2f}s)")
        else:
            print(f"Qwen3 completed in {result_qwen['inference_time']:.2f}s")
        # Step 2: Evaluate RWKV7 (using cached model)
        if rwkv_predicted_time is not None:
            progress(0, desc=f"Evaluating with RWKV7... (estimated: {rwkv_predicted_time:.1f}s)")
        else:
            progress(0, desc="Evaluating with RWKV7...")
        result_rwkv = evaluate_rwkv7_single_sample(_rwkv_model, _rwkv_tokenizer, text)
        # Save stats and print comparison
        _stats_manager.add_record("rwkv", rwkv_token_count, result_rwkv["inference_time"])
        if rwkv_predicted_time is not None:
            print(f"RWKV7 completed in {result_rwkv['inference_time']:.2f}s (predicted: {rwkv_predicted_time:.2f}s)")
        else:
            print(f"RWKV7 completed in {result_rwkv['inference_time']:.2f}s")
        # Step 3: Generate visualization (model A = RWKV7, model B = Qwen3)
        progress(0, desc="Generating visualization...")
        html = generate_comparison_html(
            text=text,
            byte_losses_a=result_rwkv["byte_wise_losses"],
            byte_losses_b=result_qwen["byte_wise_losses"],
            model_a_name="RWKV7-G1C-1.5B",
            model_b_name="Qwen3-1.7B-Base",
            topk_predictions_a=result_rwkv["top5_predictions"],
            topk_predictions_b=result_qwen["top5_predictions"],
            tokenizer_a=result_rwkv["tokenizer"],
            tokenizer_b=result_qwen["tokenizer"],
            model_type_a="rwkv7",
            model_type_b="hf",
        )
        # Wrap HTML for iframe display
        wrapped_html = wrap_html_in_iframe(html)
        return wrapped_html
    except torch.cuda.OutOfMemoryError:
        # Free what we can before surfacing a user-facing error
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        raise gr.Error("GPU memory insufficient. Please try:\n" "1. Use shorter text\n" "2. Wait a moment and try again")
    except Exception as e:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        raise gr.Error(f"Evaluation failed: {str(e)}")
def clear_inputs():
    """Reset the text input and the visualization output to empty states."""
    return "", None
def get_default_example():
    """Return (example_text, wrapped_html) for initial page load.

    Falls back to an empty textbox and no visualization when the precomputed
    example cache was not loaded at startup.
    """
    global _precomputed_html, _precomputed_text
    if not (_precomputed_html and _precomputed_text):
        return "", None
    return _precomputed_text, wrap_html_in_iframe(_precomputed_html)
# Build Gradio UI: input textbox + buttons on top, visualization below.
with gr.Blocks(title="Compression-Lens: RWKV-7 vs Qwen3", theme=gr.themes.Soft()) as demo:
    # Static header with title and project badges
    gr.HTML(
        """
        <div style="text-align: center; margin-bottom: 20px;">
            <h1 style="margin-bottom: 10px;">🔬 Compression-Lens: RWKV-7 vs Qwen3 Byte-Level Comparison</h1>
            <p style="margin-bottom: 15px; color: #666;">Compare the byte-level prediction performance between <strong>RWKV7-G1C-1.5B</strong> and <strong>Qwen3-1.7B-Base</strong>.</p>
            <div style="display: flex; justify-content: center; align-items: center; gap: 10px;">
                <a href="https://github.com/Jellyfish042/uncheatable_eval" target="_blank" style="text-decoration: none;">
                    <img src="https://img.shields.io/badge/GitHub-Project-181717?logo=github" alt="GitHub Project">
                </a>
                <a href="https://huggingface.co/spaces/Jellyfish042/UncheatableEval" target="_blank" style="text-decoration: none;">
                    <img src="https://img.shields.io/badge/%F0%9F%8F%86%20Leaderboard-Gradio-ff7c00" alt="Leaderboard">
                </a>
            </div>
        </div>
        """
    )
    # Input area: textbox plus Clear / Run buttons
    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder=f"Enter text to analyze (max {MAX_TEXT_LENGTH} characters)...",
                lines=10,
                max_lines=20,
            )
            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary")
                run_btn = gr.Button("▶ Run Comparison", variant="primary")
    gr.Markdown("---")
    # Output area: comparison visualization (iframe-wrapped HTML)
    with gr.Row():
        with gr.Column():
            output_html = gr.HTML(label="Visualization")
    # Event handlers
    clear_btn.click(fn=clear_inputs, outputs=[text_input, output_html])
    run_btn.click(fn=run_evaluation, inputs=[text_input], outputs=[output_html])
    # Load default (precomputed) example on page load
    demo.load(fn=get_default_example, outputs=[text_input, output_html])
if __name__ == "__main__":
    # Initialize models before launching the app (blocks until both models
    # and the precomputed example are loaded)
    initialize_models()
    # Launch the Gradio app; 0.0.0.0:7860 is the standard HF Spaces binding
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)