Jellyfish042 Claude Sonnet 4.5 committed on
Commit
d68c16d
·
1 Parent(s): 49eb0e6

Optimize model loading with caching and improve performance

Browse files

- Add global model cache to avoid reloading models on each evaluation
- Initialize both Qwen3 and RWKV7 models at startup
- Remove redundant memory cleanup between evaluations
- Simplify progress reporting with safe_progress helper
- Remove download button functionality for cleaner UI
- Add .gitignore to exclude model files and cache

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

Files changed (3) hide show
  1. .claude/settings.local.json +3 -1
  2. .gitignore +27 -0
  3. app.py +58 -84
.claude/settings.local.json CHANGED
@@ -6,7 +6,9 @@
6
  "Bash(git remote add:*)",
7
  "Bash(git push:*)",
8
  "Bash(git branch:*)",
9
- "Bash(git commit -m \"$\\(cat <<''EOF''\nFix Gradio compatibility for HuggingFace Spaces\n\n- Upgrade gradio to >=5.0.0 to fix API schema bug\n- Add server_name and server_port to demo.launch\\(\\)\n\nCo-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>\nEOF\n\\)\")"
 
 
10
  ]
11
  }
12
  }
 
6
  "Bash(git remote add:*)",
7
  "Bash(git push:*)",
8
  "Bash(git branch:*)",
9
+ "Bash(git commit -m \"$\\(cat <<''EOF''\nFix Gradio compatibility for HuggingFace Spaces\n\n- Upgrade gradio to >=5.0.0 to fix API schema bug\n- Add server_name and server_port to demo.launch\\(\\)\n\nCo-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>\nEOF\n\\)\")",
10
+ "Bash(git commit:*)",
11
+ "Bash(git reset:*)"
12
  ]
13
  }
14
  }
.gitignore ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python cache
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+
7
+ # Model files
8
+ models/
9
+ *.pth
10
+ *.bin
11
+ *.safetensors
12
+
13
+ # Virtual environment
14
+ venv/
15
+ env/
16
+ ENV/
17
+
18
+ # IDE
19
+ .vscode/
20
+ .idea/
21
+
22
+ # OS
23
+ .DS_Store
24
+ Thumbs.db
25
+
26
+ # Gradio
27
+ flagged/
app.py CHANGED
@@ -6,7 +6,6 @@ Compare byte-level prediction performance between Qwen3-1.7B-Base and RWKV7-G1C-
6
 
7
  import gc
8
  import os
9
- import tempfile
10
  from pathlib import Path
11
 
12
  import gradio as gr
@@ -30,6 +29,13 @@ SUPPORT_DIR = SCRIPT_DIR / "support"
30
  MAX_TEXT_LENGTH = 4000
31
  MIN_TEXT_LENGTH = 10
32
 
 
 
 
 
 
 
 
33
  # Example texts
34
  EXAMPLE_NEWS = """The rapid advancement of artificial intelligence has sparked both excitement and concern among researchers worldwide. While AI systems demonstrate remarkable capabilities in language understanding and generation, questions remain about their potential impact on employment and society."""
35
 
@@ -56,9 +62,6 @@ def download_rwkv_model(progress=None):
56
 
57
  MODELS_DIR.mkdir(parents=True, exist_ok=True)
58
 
59
- if progress:
60
- progress(0.1, desc="Downloading RWKV7 model...")
61
-
62
  # Download from HuggingFace Hub
63
  downloaded_path = hf_hub_download(
64
  repo_id="BlinkDL/rwkv7-g1",
@@ -132,6 +135,10 @@ def load_rwkv7_model(model_path: str):
132
  else:
133
  strategy = "cuda fp16"
134
 
 
 
 
 
135
  model = RWKV(model=model_path, strategy=strategy)
136
 
137
  vocab_path = str(SUPPORT_DIR / "rwkv_vocab_v20230424.txt")
@@ -156,6 +163,27 @@ def validate_input(text: str) -> tuple[bool, str]:
156
  return True, text
157
 
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  def wrap_html_in_iframe(html: str) -> str:
160
  """Wrap HTML in an iframe for Gradio display."""
161
  escaped = html.replace('"', '&quot;')
@@ -173,6 +201,9 @@ def run_evaluation(text: str, progress=gr.Progress()):
173
  from core.evaluator import evaluate_hf_single_sample, evaluate_rwkv7_single_sample
174
  from visualization.html_generator import generate_comparison_html
175
 
 
 
 
176
  # Validate input
177
  valid, result = validate_input(text)
178
  if not valid:
@@ -180,52 +211,33 @@ def run_evaluation(text: str, progress=gr.Progress()):
180
 
181
  text = result # Use cleaned text
182
 
183
- try:
184
- # Step 1: Download RWKV model if needed
185
- progress(0.05, desc="Checking RWKV7 model...")
186
- rwkv_model_path = download_rwkv_model(progress)
187
-
188
- # Step 2: Load Qwen model
189
- progress(0.1, desc="Loading Qwen3-1.7B-Base...")
190
- qwen_model, qwen_tokenizer = load_qwen_model()
191
 
192
- # Step 3: Evaluate Qwen
193
- progress(0.3, desc="Evaluating with Qwen3...")
 
194
  result_qwen = evaluate_hf_single_sample(
195
- qwen_model,
196
- qwen_tokenizer,
197
  text,
198
  bos_mode="add_newline_token"
199
  )
200
 
201
- # Step 4: Free Qwen memory
202
- progress(0.4, desc="Freeing memory...")
203
- del qwen_model
204
- if torch.cuda.is_available():
205
- torch.cuda.empty_cache()
206
- gc.collect()
207
-
208
- # Step 5: Load RWKV7 model
209
- progress(0.5, desc="Loading RWKV7-G1C-1.5B...")
210
- rwkv_model, rwkv_tokenizer = load_rwkv7_model(rwkv_model_path)
211
-
212
- # Step 6: Evaluate RWKV7
213
- progress(0.7, desc="Evaluating with RWKV7...")
214
  result_rwkv = evaluate_rwkv7_single_sample(
215
- rwkv_model,
216
- rwkv_tokenizer,
217
  text
218
  )
219
 
220
- # Step 7: Free RWKV memory
221
- progress(0.8, desc="Freeing memory...")
222
- del rwkv_model
223
- if torch.cuda.is_available():
224
- torch.cuda.empty_cache()
225
- gc.collect()
226
-
227
  # Step 8: Generate visualization
228
- progress(0.9, desc="Generating visualization...")
229
  html = generate_comparison_html(
230
  text=text,
231
  byte_losses_a=result_qwen["byte_wise_losses"],
@@ -243,11 +255,7 @@ def run_evaluation(text: str, progress=gr.Progress()):
243
  # Wrap HTML for iframe display
244
  wrapped_html = wrap_html_in_iframe(html)
245
 
246
- # Store HTML for download
247
- global _last_html_content
248
- _last_html_content = html
249
-
250
- progress(1.0, desc="Done!")
251
 
252
  return wrapped_html
253
 
@@ -272,10 +280,6 @@ def clear_inputs():
272
  return "", None
273
 
274
 
275
- # Global variable to store the last generated HTML for download
276
- _last_html_content = None
277
-
278
-
279
  # Build Gradio UI
280
  with gr.Blocks(
281
  title="UncheatableEval: Qwen3 vs RWKV7",
@@ -320,7 +324,6 @@ with gr.Blocks(
320
  with gr.Row():
321
  with gr.Column():
322
  output_html = gr.HTML(label="Visualization")
323
- download_file = gr.File(label="📥 Download HTML", visible=False)
324
 
325
  # Event handlers
326
  news_btn.click(fn=lambda: EXAMPLE_NEWS, outputs=[text_input])
@@ -332,45 +335,16 @@ with gr.Blocks(
332
  outputs=[text_input, output_html]
333
  )
334
 
335
- def run_and_prepare_download(text, progress=gr.Progress()):
336
- """Run evaluation and prepare download file."""
337
- wrapped_html = run_evaluation(text, progress)
338
-
339
- # Save HTML for download
340
- temp_file = tempfile.NamedTemporaryFile(
341
- mode='w',
342
- suffix='.html',
343
- delete=False,
344
- encoding='utf-8'
345
- )
346
- temp_file.write(_last_html_content)
347
- temp_file.close()
348
-
349
- return wrapped_html, temp_file.name
350
-
351
  run_btn.click(
352
- fn=run_and_prepare_download,
353
  inputs=[text_input],
354
- outputs=[output_html, download_btn]
355
  )
356
 
357
- gr.Markdown("""
358
- ---
359
- ### About
360
-
361
- This tool uses [UncheatableEval](https://github.com/Jellyfish042/UncheatableEval) to compare
362
- language model performance at the byte level.
363
-
364
- **Models:**
365
- - **Qwen3-1.7B-Base**: Transformer-based model from Alibaba
366
- - **RWKV7-G1C-1.5B**: Linear attention model from RWKV team
367
-
368
- **How it works:**
369
- 1. Both models predict each byte in the input text
370
- 2. Lower prediction loss = better compression = better understanding
371
- 3. The visualization shows where each model performs better or worse
372
- """)
373
-
374
 
375
  if __name__ == "__main__":
 
 
 
 
376
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
6
 
7
  import gc
8
  import os
 
9
  from pathlib import Path
10
 
11
  import gradio as gr
 
29
  MAX_TEXT_LENGTH = 4000
30
  MIN_TEXT_LENGTH = 10
31
 
32
+ # Global model cache
33
+ _qwen_model = None
34
+ _qwen_tokenizer = None
35
+ _rwkv_model = None
36
+ _rwkv_tokenizer = None
37
+ _rwkv_model_path = None
38
+
39
  # Example texts
40
  EXAMPLE_NEWS = """The rapid advancement of artificial intelligence has sparked both excitement and concern among researchers worldwide. While AI systems demonstrate remarkable capabilities in language understanding and generation, questions remain about their potential impact on employment and society."""
41
 
 
62
 
63
  MODELS_DIR.mkdir(parents=True, exist_ok=True)
64
 
 
 
 
65
  # Download from HuggingFace Hub
66
  downloaded_path = hf_hub_download(
67
  repo_id="BlinkDL/rwkv7-g1",
 
135
  else:
136
  strategy = "cuda fp16"
137
 
138
+ # RWKV library automatically adds .pth extension, so remove it if present
139
+ if model_path.endswith('.pth'):
140
+ model_path = model_path[:-4]
141
+
142
  model = RWKV(model=model_path, strategy=strategy)
143
 
144
  vocab_path = str(SUPPORT_DIR / "rwkv_vocab_v20230424.txt")
 
163
  return True, text
164
 
165
 
166
+ def initialize_models():
167
+ """Initialize and cache both models at startup."""
168
+ global _qwen_model, _qwen_tokenizer, _rwkv_model, _rwkv_tokenizer, _rwkv_model_path
169
+
170
+ print("Initializing models...")
171
+
172
+ # Download RWKV model if needed
173
+ print("Checking RWKV7 model...")
174
+ _rwkv_model_path = download_rwkv_model()
175
+
176
+ # Load Qwen model
177
+ print("Loading Qwen3-1.7B-Base...")
178
+ _qwen_model, _qwen_tokenizer = load_qwen_model()
179
+
180
+ # Load RWKV7 model
181
+ print("Loading RWKV7-G1C-1.5B...")
182
+ _rwkv_model, _rwkv_tokenizer = load_rwkv7_model(_rwkv_model_path)
183
+
184
+ print("Models loaded successfully!")
185
+
186
+
187
  def wrap_html_in_iframe(html: str) -> str:
188
  """Wrap HTML in an iframe for Gradio display."""
189
  escaped = html.replace('"', '&quot;')
 
201
  from core.evaluator import evaluate_hf_single_sample, evaluate_rwkv7_single_sample
202
  from visualization.html_generator import generate_comparison_html
203
 
204
+ # Use cached models
205
+ global _qwen_model, _qwen_tokenizer, _rwkv_model, _rwkv_tokenizer
206
+
207
  # Validate input
208
  valid, result = validate_input(text)
209
  if not valid:
 
211
 
212
  text = result # Use cleaned text
213
 
214
+ # Helper function to safely call progress
215
+ def safe_progress(value, desc):
216
+ try:
217
+ progress(value, desc=desc)
218
+ except:
219
+ pass
 
 
220
 
221
+ try:
222
+ # Step 1: Evaluate Qwen (using cached model)
223
+ safe_progress(0.2, "Evaluating with Qwen3...")
224
  result_qwen = evaluate_hf_single_sample(
225
+ _qwen_model,
226
+ _qwen_tokenizer,
227
  text,
228
  bos_mode="add_newline_token"
229
  )
230
 
231
+ # Step 2: Evaluate RWKV7 (using cached model)
232
+ safe_progress(0.6, "Evaluating with RWKV7...")
 
 
 
 
 
 
 
 
 
 
 
233
  result_rwkv = evaluate_rwkv7_single_sample(
234
+ _rwkv_model,
235
+ _rwkv_tokenizer,
236
  text
237
  )
238
 
 
 
 
 
 
 
 
239
  # Step 8: Generate visualization
240
+ safe_progress(0.9, "Generating visualization...")
241
  html = generate_comparison_html(
242
  text=text,
243
  byte_losses_a=result_qwen["byte_wise_losses"],
 
255
  # Wrap HTML for iframe display
256
  wrapped_html = wrap_html_in_iframe(html)
257
 
258
+ safe_progress(1.0, "Done!")
 
 
 
 
259
 
260
  return wrapped_html
261
 
 
280
  return "", None
281
 
282
 
 
 
 
 
283
  # Build Gradio UI
284
  with gr.Blocks(
285
  title="UncheatableEval: Qwen3 vs RWKV7",
 
324
  with gr.Row():
325
  with gr.Column():
326
  output_html = gr.HTML(label="Visualization")
 
327
 
328
  # Event handlers
329
  news_btn.click(fn=lambda: EXAMPLE_NEWS, outputs=[text_input])
 
335
  outputs=[text_input, output_html]
336
  )
337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  run_btn.click(
339
+ fn=run_evaluation,
340
  inputs=[text_input],
341
+ outputs=[output_html]
342
  )
343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
  if __name__ == "__main__":
346
+ # Initialize models before launching the app
347
+ initialize_models()
348
+
349
+ # Launch the Gradio app
350
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)