Spaces:
Running
Running
Commit ·
d68c16d
1
Parent(s): 49eb0e6
Optimize model loading with caching and improve performance
- Add global model cache to avoid reloading models on each evaluation
- Initialize both Qwen3 and RWKV7 models at startup
- Remove redundant memory cleanup between evaluations
- Simplify progress reporting with safe_progress helper
- Remove download button functionality for cleaner UI
- Add .gitignore to exclude model files and cache
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
- .claude/settings.local.json +3 -1
- .gitignore +27 -0
- app.py +58 -84
.claude/settings.local.json
CHANGED
|
@@ -6,7 +6,9 @@
|
|
| 6 |
"Bash(git remote add:*)",
|
| 7 |
"Bash(git push:*)",
|
| 8 |
"Bash(git branch:*)",
|
| 9 |
-
"Bash(git commit -m \"$\\(cat <<''EOF''\nFix Gradio compatibility for HuggingFace Spaces\n\n- Upgrade gradio to >=5.0.0 to fix API schema bug\n- Add server_name and server_port to demo.launch\\(\\)\n\nCo-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>\nEOF\n\\)\")"
|
|
|
|
|
|
|
| 10 |
]
|
| 11 |
}
|
| 12 |
}
|
|
|
|
| 6 |
"Bash(git remote add:*)",
|
| 7 |
"Bash(git push:*)",
|
| 8 |
"Bash(git branch:*)",
|
| 9 |
+
"Bash(git commit -m \"$\\(cat <<''EOF''\nFix Gradio compatibility for HuggingFace Spaces\n\n- Upgrade gradio to >=5.0.0 to fix API schema bug\n- Add server_name and server_port to demo.launch\\(\\)\n\nCo-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>\nEOF\n\\)\")",
|
| 10 |
+
"Bash(git commit:*)",
|
| 11 |
+
"Bash(git reset:*)"
|
| 12 |
]
|
| 13 |
}
|
| 14 |
}
|
.gitignore
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python cache
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
|
| 7 |
+
# Model files
|
| 8 |
+
models/
|
| 9 |
+
*.pth
|
| 10 |
+
*.bin
|
| 11 |
+
*.safetensors
|
| 12 |
+
|
| 13 |
+
# Virtual environment
|
| 14 |
+
venv/
|
| 15 |
+
env/
|
| 16 |
+
ENV/
|
| 17 |
+
|
| 18 |
+
# IDE
|
| 19 |
+
.vscode/
|
| 20 |
+
.idea/
|
| 21 |
+
|
| 22 |
+
# OS
|
| 23 |
+
.DS_Store
|
| 24 |
+
Thumbs.db
|
| 25 |
+
|
| 26 |
+
# Gradio
|
| 27 |
+
flagged/
|
app.py
CHANGED
|
@@ -6,7 +6,6 @@ Compare byte-level prediction performance between Qwen3-1.7B-Base and RWKV7-G1C-
|
|
| 6 |
|
| 7 |
import gc
|
| 8 |
import os
|
| 9 |
-
import tempfile
|
| 10 |
from pathlib import Path
|
| 11 |
|
| 12 |
import gradio as gr
|
|
@@ -30,6 +29,13 @@ SUPPORT_DIR = SCRIPT_DIR / "support"
|
|
| 30 |
MAX_TEXT_LENGTH = 4000
|
| 31 |
MIN_TEXT_LENGTH = 10
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
# Example texts
|
| 34 |
EXAMPLE_NEWS = """The rapid advancement of artificial intelligence has sparked both excitement and concern among researchers worldwide. While AI systems demonstrate remarkable capabilities in language understanding and generation, questions remain about their potential impact on employment and society."""
|
| 35 |
|
|
@@ -56,9 +62,6 @@ def download_rwkv_model(progress=None):
|
|
| 56 |
|
| 57 |
MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
| 58 |
|
| 59 |
-
if progress:
|
| 60 |
-
progress(0.1, desc="Downloading RWKV7 model...")
|
| 61 |
-
|
| 62 |
# Download from HuggingFace Hub
|
| 63 |
downloaded_path = hf_hub_download(
|
| 64 |
repo_id="BlinkDL/rwkv7-g1",
|
|
@@ -132,6 +135,10 @@ def load_rwkv7_model(model_path: str):
|
|
| 132 |
else:
|
| 133 |
strategy = "cuda fp16"
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
model = RWKV(model=model_path, strategy=strategy)
|
| 136 |
|
| 137 |
vocab_path = str(SUPPORT_DIR / "rwkv_vocab_v20230424.txt")
|
|
@@ -156,6 +163,27 @@ def validate_input(text: str) -> tuple[bool, str]:
|
|
| 156 |
return True, text
|
| 157 |
|
| 158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
def wrap_html_in_iframe(html: str) -> str:
|
| 160 |
"""Wrap HTML in an iframe for Gradio display."""
|
| 161 |
escaped = html.replace('"', '&quot;')
|
|
@@ -173,6 +201,9 @@ def run_evaluation(text: str, progress=gr.Progress()):
|
|
| 173 |
from core.evaluator import evaluate_hf_single_sample, evaluate_rwkv7_single_sample
|
| 174 |
from visualization.html_generator import generate_comparison_html
|
| 175 |
|
|
|
|
|
|
|
|
|
|
| 176 |
# Validate input
|
| 177 |
valid, result = validate_input(text)
|
| 178 |
if not valid:
|
|
@@ -180,52 +211,33 @@ def run_evaluation(text: str, progress=gr.Progress()):
|
|
| 180 |
|
| 181 |
text = result # Use cleaned text
|
| 182 |
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
progress(0.1, desc="Loading Qwen3-1.7B-Base...")
|
| 190 |
-
qwen_model, qwen_tokenizer = load_qwen_model()
|
| 191 |
|
| 192 |
-
|
| 193 |
-
|
|
|
|
| 194 |
result_qwen = evaluate_hf_single_sample(
|
| 195 |
-
|
| 196 |
-
|
| 197 |
text,
|
| 198 |
bos_mode="add_newline_token"
|
| 199 |
)
|
| 200 |
|
| 201 |
-
# Step
|
| 202 |
-
|
| 203 |
-
del qwen_model
|
| 204 |
-
if torch.cuda.is_available():
|
| 205 |
-
torch.cuda.empty_cache()
|
| 206 |
-
gc.collect()
|
| 207 |
-
|
| 208 |
-
# Step 5: Load RWKV7 model
|
| 209 |
-
progress(0.5, desc="Loading RWKV7-G1C-1.5B...")
|
| 210 |
-
rwkv_model, rwkv_tokenizer = load_rwkv7_model(rwkv_model_path)
|
| 211 |
-
|
| 212 |
-
# Step 6: Evaluate RWKV7
|
| 213 |
-
progress(0.7, desc="Evaluating with RWKV7...")
|
| 214 |
result_rwkv = evaluate_rwkv7_single_sample(
|
| 215 |
-
|
| 216 |
-
|
| 217 |
text
|
| 218 |
)
|
| 219 |
|
| 220 |
-
# Step 7: Free RWKV memory
|
| 221 |
-
progress(0.8, desc="Freeing memory...")
|
| 222 |
-
del rwkv_model
|
| 223 |
-
if torch.cuda.is_available():
|
| 224 |
-
torch.cuda.empty_cache()
|
| 225 |
-
gc.collect()
|
| 226 |
-
|
| 227 |
# Step 8: Generate visualization
|
| 228 |
-
|
| 229 |
html = generate_comparison_html(
|
| 230 |
text=text,
|
| 231 |
byte_losses_a=result_qwen["byte_wise_losses"],
|
|
@@ -243,11 +255,7 @@ def run_evaluation(text: str, progress=gr.Progress()):
|
|
| 243 |
# Wrap HTML for iframe display
|
| 244 |
wrapped_html = wrap_html_in_iframe(html)
|
| 245 |
|
| 246 |
-
|
| 247 |
-
global _last_html_content
|
| 248 |
-
_last_html_content = html
|
| 249 |
-
|
| 250 |
-
progress(1.0, desc="Done!")
|
| 251 |
|
| 252 |
return wrapped_html
|
| 253 |
|
|
@@ -272,10 +280,6 @@ def clear_inputs():
|
|
| 272 |
return "", None
|
| 273 |
|
| 274 |
|
| 275 |
-
# Global variable to store the last generated HTML for download
|
| 276 |
-
_last_html_content = None
|
| 277 |
-
|
| 278 |
-
|
| 279 |
# Build Gradio UI
|
| 280 |
with gr.Blocks(
|
| 281 |
title="UncheatableEval: Qwen3 vs RWKV7",
|
|
@@ -320,7 +324,6 @@ with gr.Blocks(
|
|
| 320 |
with gr.Row():
|
| 321 |
with gr.Column():
|
| 322 |
output_html = gr.HTML(label="Visualization")
|
| 323 |
-
download_file = gr.File(label="📥 Download HTML", visible=False)
|
| 324 |
|
| 325 |
# Event handlers
|
| 326 |
news_btn.click(fn=lambda: EXAMPLE_NEWS, outputs=[text_input])
|
|
@@ -332,45 +335,16 @@ with gr.Blocks(
|
|
| 332 |
outputs=[text_input, output_html]
|
| 333 |
)
|
| 334 |
|
| 335 |
-
def run_and_prepare_download(text, progress=gr.Progress()):
|
| 336 |
-
"""Run evaluation and prepare download file."""
|
| 337 |
-
wrapped_html = run_evaluation(text, progress)
|
| 338 |
-
|
| 339 |
-
# Save HTML for download
|
| 340 |
-
temp_file = tempfile.NamedTemporaryFile(
|
| 341 |
-
mode='w',
|
| 342 |
-
suffix='.html',
|
| 343 |
-
delete=False,
|
| 344 |
-
encoding='utf-8'
|
| 345 |
-
)
|
| 346 |
-
temp_file.write(_last_html_content)
|
| 347 |
-
temp_file.close()
|
| 348 |
-
|
| 349 |
-
return wrapped_html, temp_file.name
|
| 350 |
-
|
| 351 |
run_btn.click(
|
| 352 |
-
fn=run_and_prepare_download,
|
| 353 |
inputs=[text_input],
|
| 354 |
-
outputs=[output_html, download_file]
|
| 355 |
)
|
| 356 |
|
| 357 |
-
gr.Markdown("""
|
| 358 |
-
---
|
| 359 |
-
### About
|
| 360 |
-
|
| 361 |
-
This tool uses [UncheatableEval](https://github.com/Jellyfish042/UncheatableEval) to compare
|
| 362 |
-
language model performance at the byte level.
|
| 363 |
-
|
| 364 |
-
**Models:**
|
| 365 |
-
- **Qwen3-1.7B-Base**: Transformer-based model from Alibaba
|
| 366 |
-
- **RWKV7-G1C-1.5B**: Linear attention model from RWKV team
|
| 367 |
-
|
| 368 |
-
**How it works:**
|
| 369 |
-
1. Both models predict each byte in the input text
|
| 370 |
-
2. Lower prediction loss = better compression = better understanding
|
| 371 |
-
3. The visualization shows where each model performs better or worse
|
| 372 |
-
""")
|
| 373 |
-
|
| 374 |
|
| 375 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
|
|
|
| 6 |
|
| 7 |
import gc
|
| 8 |
import os
|
|
|
|
| 9 |
from pathlib import Path
|
| 10 |
|
| 11 |
import gradio as gr
|
|
|
|
| 29 |
MAX_TEXT_LENGTH = 4000
|
| 30 |
MIN_TEXT_LENGTH = 10
|
| 31 |
|
| 32 |
+
# Global model cache
|
| 33 |
+
_qwen_model = None
|
| 34 |
+
_qwen_tokenizer = None
|
| 35 |
+
_rwkv_model = None
|
| 36 |
+
_rwkv_tokenizer = None
|
| 37 |
+
_rwkv_model_path = None
|
| 38 |
+
|
| 39 |
# Example texts
|
| 40 |
EXAMPLE_NEWS = """The rapid advancement of artificial intelligence has sparked both excitement and concern among researchers worldwide. While AI systems demonstrate remarkable capabilities in language understanding and generation, questions remain about their potential impact on employment and society."""
|
| 41 |
|
|
|
|
| 62 |
|
| 63 |
MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
| 64 |
|
|
|
|
|
|
|
|
|
|
| 65 |
# Download from HuggingFace Hub
|
| 66 |
downloaded_path = hf_hub_download(
|
| 67 |
repo_id="BlinkDL/rwkv7-g1",
|
|
|
|
| 135 |
else:
|
| 136 |
strategy = "cuda fp16"
|
| 137 |
|
| 138 |
+
# RWKV library automatically adds .pth extension, so remove it if present
|
| 139 |
+
if model_path.endswith('.pth'):
|
| 140 |
+
model_path = model_path[:-4]
|
| 141 |
+
|
| 142 |
model = RWKV(model=model_path, strategy=strategy)
|
| 143 |
|
| 144 |
vocab_path = str(SUPPORT_DIR / "rwkv_vocab_v20230424.txt")
|
|
|
|
| 163 |
return True, text
|
| 164 |
|
| 165 |
|
| 166 |
+
def initialize_models():
|
| 167 |
+
"""Initialize and cache both models at startup."""
|
| 168 |
+
global _qwen_model, _qwen_tokenizer, _rwkv_model, _rwkv_tokenizer, _rwkv_model_path
|
| 169 |
+
|
| 170 |
+
print("Initializing models...")
|
| 171 |
+
|
| 172 |
+
# Download RWKV model if needed
|
| 173 |
+
print("Checking RWKV7 model...")
|
| 174 |
+
_rwkv_model_path = download_rwkv_model()
|
| 175 |
+
|
| 176 |
+
# Load Qwen model
|
| 177 |
+
print("Loading Qwen3-1.7B-Base...")
|
| 178 |
+
_qwen_model, _qwen_tokenizer = load_qwen_model()
|
| 179 |
+
|
| 180 |
+
# Load RWKV7 model
|
| 181 |
+
print("Loading RWKV7-G1C-1.5B...")
|
| 182 |
+
_rwkv_model, _rwkv_tokenizer = load_rwkv7_model(_rwkv_model_path)
|
| 183 |
+
|
| 184 |
+
print("Models loaded successfully!")
|
| 185 |
+
|
| 186 |
+
|
| 187 |
def wrap_html_in_iframe(html: str) -> str:
|
| 188 |
"""Wrap HTML in an iframe for Gradio display."""
|
| 189 |
escaped = html.replace('"', '&quot;')
|
|
|
|
| 201 |
from core.evaluator import evaluate_hf_single_sample, evaluate_rwkv7_single_sample
|
| 202 |
from visualization.html_generator import generate_comparison_html
|
| 203 |
|
| 204 |
+
# Use cached models
|
| 205 |
+
global _qwen_model, _qwen_tokenizer, _rwkv_model, _rwkv_tokenizer
|
| 206 |
+
|
| 207 |
# Validate input
|
| 208 |
valid, result = validate_input(text)
|
| 209 |
if not valid:
|
|
|
|
| 211 |
|
| 212 |
text = result # Use cleaned text
|
| 213 |
|
| 214 |
+
# Helper function to safely call progress
|
| 215 |
+
def safe_progress(value, desc):
|
| 216 |
+
try:
|
| 217 |
+
progress(value, desc=desc)
|
| 218 |
+
except:
|
| 219 |
+
pass
|
|
|
|
|
|
|
| 220 |
|
| 221 |
+
try:
|
| 222 |
+
# Step 1: Evaluate Qwen (using cached model)
|
| 223 |
+
safe_progress(0.2, "Evaluating with Qwen3...")
|
| 224 |
result_qwen = evaluate_hf_single_sample(
|
| 225 |
+
_qwen_model,
|
| 226 |
+
_qwen_tokenizer,
|
| 227 |
text,
|
| 228 |
bos_mode="add_newline_token"
|
| 229 |
)
|
| 230 |
|
| 231 |
+
# Step 2: Evaluate RWKV7 (using cached model)
|
| 232 |
+
safe_progress(0.6, "Evaluating with RWKV7...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
result_rwkv = evaluate_rwkv7_single_sample(
|
| 234 |
+
_rwkv_model,
|
| 235 |
+
_rwkv_tokenizer,
|
| 236 |
text
|
| 237 |
)
|
| 238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
# Step 8: Generate visualization
|
| 240 |
+
safe_progress(0.9, "Generating visualization...")
|
| 241 |
html = generate_comparison_html(
|
| 242 |
text=text,
|
| 243 |
byte_losses_a=result_qwen["byte_wise_losses"],
|
|
|
|
| 255 |
# Wrap HTML for iframe display
|
| 256 |
wrapped_html = wrap_html_in_iframe(html)
|
| 257 |
|
| 258 |
+
safe_progress(1.0, "Done!")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
|
| 260 |
return wrapped_html
|
| 261 |
|
|
|
|
| 280 |
return "", None
|
| 281 |
|
| 282 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
# Build Gradio UI
|
| 284 |
with gr.Blocks(
|
| 285 |
title="UncheatableEval: Qwen3 vs RWKV7",
|
|
|
|
| 324 |
with gr.Row():
|
| 325 |
with gr.Column():
|
| 326 |
output_html = gr.HTML(label="Visualization")
|
|
|
|
| 327 |
|
| 328 |
# Event handlers
|
| 329 |
news_btn.click(fn=lambda: EXAMPLE_NEWS, outputs=[text_input])
|
|
|
|
| 335 |
outputs=[text_input, output_html]
|
| 336 |
)
|
| 337 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
run_btn.click(
|
| 339 |
+
fn=run_evaluation,
|
| 340 |
inputs=[text_input],
|
| 341 |
+
outputs=[output_html]
|
| 342 |
)
|
| 343 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
|
| 345 |
if __name__ == "__main__":
|
| 346 |
+
# Initialize models before launching the app
|
| 347 |
+
initialize_models()
|
| 348 |
+
|
| 349 |
+
# Launch the Gradio app
|
| 350 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|