Jellyfish042 Claude Sonnet 4.5 committed on
Commit
cddd3a5
·
1 Parent(s): 6bbbdc0

Apply code formatting and update title

Browse files

Changes:
- Applied automatic code formatting (line length, quotes)
- Updated title: "Qwen3 vs RWKV7" → "RWKV-7 vs Qwen3"
- Reformatted multi-line function calls for consistency

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +24 -71
app.py CHANGED
@@ -50,10 +50,7 @@ def download_rwkv_model(progress=None):
50
 
51
  # Download from HuggingFace Hub
52
  downloaded_path = hf_hub_download(
53
- repo_id="BlinkDL/rwkv7-g1",
54
- filename=RWKV_MODEL_FILENAME,
55
- local_dir=str(MODELS_DIR),
56
- local_dir_use_symlinks=False
57
  )
58
 
59
  return downloaded_path
@@ -63,40 +60,18 @@ def load_qwen_model():
63
  """Load Qwen3-1.7B-Base model."""
64
  from transformers import AutoTokenizer, AutoModelForCausalLM
65
 
66
- tokenizer = AutoTokenizer.from_pretrained(
67
- QWEN_MODEL_ID,
68
- trust_remote_code=True
69
- )
70
 
71
  # Configure based on device
72
  if IS_CPU:
73
- model_kwargs = {
74
- "torch_dtype": torch.float32,
75
- "device_map": None,
76
- "trust_remote_code": True,
77
- "low_cpu_mem_usage": True
78
- }
79
- model = AutoModelForCausalLM.from_pretrained(
80
- QWEN_MODEL_ID,
81
- **model_kwargs
82
- ).eval()
83
  else:
84
- model_kwargs = {
85
- "torch_dtype": torch.bfloat16,
86
- "device_map": "auto",
87
- "trust_remote_code": True
88
- }
89
  try:
90
- model = AutoModelForCausalLM.from_pretrained(
91
- QWEN_MODEL_ID,
92
- attn_implementation="flash_attention_2",
93
- **model_kwargs
94
- ).eval()
95
  except Exception:
96
- model = AutoModelForCausalLM.from_pretrained(
97
- QWEN_MODEL_ID,
98
- **model_kwargs
99
- ).eval()
100
 
101
  return model, tokenizer
102
 
@@ -122,7 +97,7 @@ def load_rwkv7_model(model_path: str):
122
  strategy = "cuda fp16"
123
 
124
  # RWKV library automatically adds .pth extension, so remove it if present
125
- if model_path.endswith('.pth'):
126
  model_path = model_path[:-4]
127
 
128
  model = RWKV(model=model_path, strategy=strategy)
@@ -174,14 +149,14 @@ def wrap_html_in_iframe(html: str) -> str:
174
  """Wrap HTML in an iframe for Gradio display."""
175
  # For srcdoc attribute, we only need to escape quotes
176
  # The HTML entities inside (like &quot;, &#10;) should remain as-is
177
- escaped = html.replace('"', '&quot;')
178
- return f'''
179
  <div style="width:100%;height:700px;border:1px solid #ddd;border-radius:8px;overflow:hidden;">
180
  <iframe srcdoc="{escaped}"
181
  style="width:100%;height:100%;border:none;"
182
  sandbox="allow-scripts"></iframe>
183
  </div>
184
- '''
185
 
186
 
187
  def run_evaluation(text: str, progress=gr.Progress()):
@@ -202,20 +177,11 @@ def run_evaluation(text: str, progress=gr.Progress()):
202
  try:
203
  # Step 1: Evaluate Qwen (using cached model)
204
  progress(0, desc="Evaluating with Qwen3...")
205
- result_qwen = evaluate_hf_single_sample(
206
- _qwen_model,
207
- _qwen_tokenizer,
208
- text,
209
- bos_mode="add_newline_token"
210
- )
211
 
212
  # Step 2: Evaluate RWKV7 (using cached model)
213
  progress(0, desc="Evaluating with RWKV7...")
214
- result_rwkv = evaluate_rwkv7_single_sample(
215
- _rwkv_model,
216
- _rwkv_tokenizer,
217
- text
218
- )
219
 
220
  # Step 3: Generate visualization
221
  progress(0, desc="Generating visualization...")
@@ -230,7 +196,7 @@ def run_evaluation(text: str, progress=gr.Progress()):
230
  tokenizer_a=result_rwkv["tokenizer"],
231
  tokenizer_b=result_qwen["tokenizer"],
232
  model_type_a="rwkv7",
233
- model_type_b="hf"
234
  )
235
 
236
  # Wrap HTML for iframe display
@@ -242,11 +208,7 @@ def run_evaluation(text: str, progress=gr.Progress()):
242
  if torch.cuda.is_available():
243
  torch.cuda.empty_cache()
244
  gc.collect()
245
- raise gr.Error(
246
- "GPU memory insufficient. Please try:\n"
247
- "1. Use shorter text\n"
248
- "2. Wait a moment and try again"
249
- )
250
  except Exception as e:
251
  if torch.cuda.is_available():
252
  torch.cuda.empty_cache()
@@ -260,15 +222,13 @@ def clear_inputs():
260
 
261
 
262
  # Build Gradio UI
263
- with gr.Blocks(
264
- title="Compression-Lens: Qwen3 vs RWKV7",
265
- theme=gr.themes.Soft()
266
- ) as demo:
267
- gr.Markdown("""
268
- # 🔬 Compression-Lens: Qwen3 vs RWKV7 Byte-Level Comparison
269
-
270
- Compare the byte-level prediction performance between **Qwen3-1.7B-Base** and **RWKV7-G1C-1.5B**.
271
- """)
272
 
273
  with gr.Row():
274
  with gr.Column(scale=1):
@@ -290,16 +250,9 @@ with gr.Blocks(
290
  output_html = gr.HTML(label="Visualization")
291
 
292
  # Event handlers
293
- clear_btn.click(
294
- fn=clear_inputs,
295
- outputs=[text_input, output_html]
296
- )
297
 
298
- run_btn.click(
299
- fn=run_evaluation,
300
- inputs=[text_input],
301
- outputs=[output_html]
302
- )
303
 
304
 
305
  if __name__ == "__main__":
 
50
 
51
  # Download from HuggingFace Hub
52
  downloaded_path = hf_hub_download(
53
+ repo_id="BlinkDL/rwkv7-g1", filename=RWKV_MODEL_FILENAME, local_dir=str(MODELS_DIR), local_dir_use_symlinks=False
 
 
 
54
  )
55
 
56
  return downloaded_path
 
60
  """Load Qwen3-1.7B-Base model."""
61
  from transformers import AutoTokenizer, AutoModelForCausalLM
62
 
63
+ tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID, trust_remote_code=True)
 
 
 
64
 
65
  # Configure based on device
66
  if IS_CPU:
67
+ model_kwargs = {"torch_dtype": torch.float32, "device_map": None, "trust_remote_code": True, "low_cpu_mem_usage": True}
68
+ model = AutoModelForCausalLM.from_pretrained(QWEN_MODEL_ID, **model_kwargs).eval()
 
 
 
 
 
 
 
 
69
  else:
70
+ model_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto", "trust_remote_code": True}
 
 
 
 
71
  try:
72
+ model = AutoModelForCausalLM.from_pretrained(QWEN_MODEL_ID, attn_implementation="flash_attention_2", **model_kwargs).eval()
 
 
 
 
73
  except Exception:
74
+ model = AutoModelForCausalLM.from_pretrained(QWEN_MODEL_ID, **model_kwargs).eval()
 
 
 
75
 
76
  return model, tokenizer
77
 
 
97
  strategy = "cuda fp16"
98
 
99
  # RWKV library automatically adds .pth extension, so remove it if present
100
+ if model_path.endswith(".pth"):
101
  model_path = model_path[:-4]
102
 
103
  model = RWKV(model=model_path, strategy=strategy)
 
149
  """Wrap HTML in an iframe for Gradio display."""
150
  # For srcdoc attribute, we only need to escape quotes
151
  # The HTML entities inside (like &quot;, &#10;) should remain as-is
152
+ escaped = html.replace('"', "&quot;")
153
+ return f"""
154
  <div style="width:100%;height:700px;border:1px solid #ddd;border-radius:8px;overflow:hidden;">
155
  <iframe srcdoc="{escaped}"
156
  style="width:100%;height:100%;border:none;"
157
  sandbox="allow-scripts"></iframe>
158
  </div>
159
+ """
160
 
161
 
162
  def run_evaluation(text: str, progress=gr.Progress()):
 
177
  try:
178
  # Step 1: Evaluate Qwen (using cached model)
179
  progress(0, desc="Evaluating with Qwen3...")
180
+ result_qwen = evaluate_hf_single_sample(_qwen_model, _qwen_tokenizer, text, bos_mode="add_newline_token")
 
 
 
 
 
181
 
182
  # Step 2: Evaluate RWKV7 (using cached model)
183
  progress(0, desc="Evaluating with RWKV7...")
184
+ result_rwkv = evaluate_rwkv7_single_sample(_rwkv_model, _rwkv_tokenizer, text)
 
 
 
 
185
 
186
  # Step 3: Generate visualization
187
  progress(0, desc="Generating visualization...")
 
196
  tokenizer_a=result_rwkv["tokenizer"],
197
  tokenizer_b=result_qwen["tokenizer"],
198
  model_type_a="rwkv7",
199
+ model_type_b="hf",
200
  )
201
 
202
  # Wrap HTML for iframe display
 
208
  if torch.cuda.is_available():
209
  torch.cuda.empty_cache()
210
  gc.collect()
211
+ raise gr.Error("GPU memory insufficient. Please try:\n" "1. Use shorter text\n" "2. Wait a moment and try again")
 
 
 
 
212
  except Exception as e:
213
  if torch.cuda.is_available():
214
  torch.cuda.empty_cache()
 
222
 
223
 
224
  # Build Gradio UI
225
+ with gr.Blocks(title="Compression-Lens: RWKV-7 vs Qwen3", theme=gr.themes.Soft()) as demo:
226
+ gr.Markdown(
227
+ """
228
+ # 🔬 Compression-Lens: RWKV-7 vs Qwen3 Byte-Level Comparison
229
+ Compare the byte-level prediction performance between **RWKV7-G1C-1.5B** and **Qwen3-1.7B-Base**.
230
+ """
231
+ )
 
 
232
 
233
  with gr.Row():
234
  with gr.Column(scale=1):
 
250
  output_html = gr.HTML(label="Visualization")
251
 
252
  # Event handlers
253
+ clear_btn.click(fn=clear_inputs, outputs=[text_input, output_html])
 
 
 
254
 
255
+ run_btn.click(fn=run_evaluation, inputs=[text_input], outputs=[output_html])
 
 
 
 
256
 
257
 
258
  if __name__ == "__main__":