AmrYassinIsFree committed on
Commit
f1c066b
·
1 Parent(s): 9d71632

enhancing the ui

Browse files
Files changed (2) hide show
  1. .github/workflows/sync-to-hf.yml +19 -0
  2. app.py +284 -80
.github/workflows/sync-to-hf.yml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Sync to HuggingFace Spaces
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+
7
+ jobs:
8
+ sync:
9
+ runs-on: ubuntu-latest
10
+ steps:
11
+ - uses: actions/checkout@v4
12
+ with:
13
+ fetch-depth: 0
14
+
15
+ - name: Push to HuggingFace Spaces
16
+ env:
17
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
18
+ run: |
19
+ git push https://amryassin:$HF_TOKEN@huggingface.co/spaces/amryassin/embedding-bench main --force
app.py CHANGED
@@ -16,7 +16,7 @@ from models import REGISTRY, ModelConfig
16
  from wrapper import load_model
17
 
18
  # ---------------------------------------------------------------------------
19
- # Page config
20
  # ---------------------------------------------------------------------------
21
  st.set_page_config(
22
  page_title="Embedding Bench",
@@ -24,26 +24,101 @@ st.set_page_config(
24
  layout="wide",
25
  )
26
 
27
- st.title("πŸ“ Embedding Bench")
28
- st.caption("Compare text embedding models on quality, speed & memory — all in your browser.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  # ---------------------------------------------------------------------------
31
  # Sidebar — configuration
32
  # ---------------------------------------------------------------------------
33
- st.sidebar.header("Models")
 
 
34
  available_models = list(REGISTRY.keys())
35
  selected_models = st.sidebar.multiselect(
36
  "Select models",
37
  available_models,
38
  default=["mpnet", "bge-small"] if len(available_models) >= 2 else available_models[:1],
 
39
  )
40
 
41
- st.sidebar.header("Datasets")
42
  available_datasets = list(DATASET_PRESETS.keys())
43
  selected_datasets = st.sidebar.multiselect(
44
  "Select dataset presets",
45
  available_datasets,
46
  default=["sts"],
 
47
  )
48
 
49
  max_pairs = st.sidebar.number_input(
@@ -55,9 +130,10 @@ max_pairs = st.sidebar.number_input(
55
  help="Limits the number of pairs evaluated. Keep low for large datasets.",
56
  )
57
 
58
- st.sidebar.header("Speed & Memory")
59
- run_speed = st.sidebar.checkbox("Run speed benchmark", value=False)
60
- run_memory = st.sidebar.checkbox("Run memory benchmark", value=False)
 
61
 
62
  corpus_size = 500
63
  num_runs = 3
@@ -68,6 +144,8 @@ if run_speed or run_memory:
68
  if run_speed:
69
  num_runs = st.sidebar.number_input("Speed runs", 1, 10, 3)
70
 
 
 
71
  # ---------------------------------------------------------------------------
72
  # Helpers
73
  # ---------------------------------------------------------------------------
@@ -107,6 +185,38 @@ def results_to_csv(results: list[dict]) -> str:
107
  return buf.getvalue()
108
 
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  # ---------------------------------------------------------------------------
111
  # Run benchmark
112
  # ---------------------------------------------------------------------------
@@ -131,7 +241,6 @@ if run_btn:
131
  cfg = REGISTRY[model_key]
132
  result: dict = {"name": cfg.name, "is_baseline": cfg.is_baseline}
133
 
134
- # Quality
135
  model = get_model(model_key)
136
  quality_results = {}
137
  for ds_cfg in ds_configs:
@@ -139,22 +248,20 @@ if run_btn:
139
  step += 1
140
  progress.progress(
141
  step / total_steps,
142
- text=f"Evaluating {cfg.name} on {ds_key}...",
143
  )
144
  quality_results[ds_key] = evaluate_quality(model, ds_cfg, max_pairs=max_pairs)
145
  result["quality"] = quality_results
146
 
147
- # Speed
148
  if run_speed:
149
  step += 1
150
- progress.progress(step / total_steps, text=f"Speed benchmark: {cfg.name}...")
151
  corpus = build_corpus(corpus_size, ds_configs[0])
152
  result["speed"] = evaluate_speed(model, corpus, num_runs=num_runs, batch_size=batch_size)
153
 
154
- # Memory
155
  if run_memory:
156
  step += 1
157
- progress.progress(step / total_steps, text=f"Memory benchmark: {cfg.name}...")
158
  from evals.memory import evaluate_memory
159
  corpus = build_corpus(corpus_size, ds_configs[0])
160
  result["memory_mb"] = evaluate_memory(
@@ -167,7 +274,6 @@ if run_btn:
167
  time.sleep(0.3)
168
  progress.empty()
169
 
170
- # Store results in session state
171
  st.session_state["results"] = results
172
  st.session_state["selected_datasets"] = selected_datasets
173
 
@@ -175,31 +281,21 @@ if run_btn:
175
  # Display results
176
  # ---------------------------------------------------------------------------
177
  if "results" not in st.session_state:
178
- st.info("Configure options in the sidebar and hit **Run Benchmark**.")
 
 
 
 
 
 
179
  st.stop()
180
 
181
  results = st.session_state["results"]
182
- selected_datasets = st.session_state["selected_datasets"]
183
-
184
- # --- Results table ---
185
- st.header("Results")
186
- flat_rows = [flatten_result(r) for r in results]
187
- st.dataframe(flat_rows, use_container_width=True)
188
-
189
- # --- CSV download ---
190
- csv_data = results_to_csv(results)
191
- st.download_button(
192
- "📥 Download CSV",
193
- data=csv_data,
194
- file_name="embedding_bench_results.csv",
195
- mime="text/csv",
196
- )
197
-
198
- # --- Charts ---
199
- st.header("Charts")
200
- models = [r["name"] for r in results]
201
 
202
- # Discover datasets
 
 
203
  ds_keys: list[str] = []
204
  for r in results:
205
  q = r.get("quality")
@@ -207,6 +303,88 @@ for r in results:
207
  ds_keys = list(q.keys())
208
  break
209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  for ds_key in ds_keys:
211
  first_metrics = None
212
  for r in results:
@@ -219,17 +397,18 @@ for ds_key in ds_keys:
219
 
220
  if "spearman" in first_metrics:
221
  values = [r.get("quality", {}).get(ds_key, {}).get("spearman", 0) for r in results]
222
- fig, ax = plt.subplots(figsize=(max(6, len(models) * 1.5), 4))
223
- bars = ax.bar(models, values, color="#4C72B0")
224
- ax.set_ylabel("Spearman Correlation")
225
- ax.set_title(f"Quality — {ds_key}")
226
- ax.set_ylim(0, 1)
 
227
  for bar, v in zip(bars, values):
228
  ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
229
- f"{v:.4f}", ha="center", va="bottom", fontsize=9)
230
  plt.xticks(rotation=30, ha="right")
231
  plt.tight_layout()
232
- st.pyplot(fig)
233
  plt.close(fig)
234
  else:
235
  metric_names = ["mrr", "recall@1", "recall@5", "recall@10"]
@@ -237,52 +416,77 @@ for ds_key in ds_keys:
237
  width = 0.18
238
  colors = ["#4C72B0", "#55A868", "#C44E52", "#8172B2"]
239
 
240
- fig, ax = plt.subplots(figsize=(max(8, len(models) * 2.2), 4.5))
 
241
  for i, (metric, color) in enumerate(zip(metric_names, colors)):
242
  values = [r.get("quality", {}).get(ds_key, {}).get(metric, 0) for r in results]
243
  offset = (i - 1.5) * width
244
- bars = ax.bar(x + offset, values, width, label=metric, color=color)
 
245
  for bar, v in zip(bars, values):
246
  ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
247
- f"{v:.2f}", ha="center", va="bottom", fontsize=7)
248
- ax.set_ylabel("Score")
249
- ax.set_title(f"Retrieval Quality — {ds_key}")
250
  ax.set_ylim(0, 1.15)
251
  ax.set_xticks(x)
252
- ax.set_xticklabels(models, rotation=30, ha="right")
253
- ax.legend()
 
254
  plt.tight_layout()
255
- st.pyplot(fig)
256
  plt.close(fig)
257
 
258
- # Speed chart
259
  speed_values = [r.get("speed", {}).get("sentences_per_second", 0) for r in results]
260
- if any(v > 0 for v in speed_values):
261
- fig, ax = plt.subplots(figsize=(max(6, len(models) * 1.5), 4))
262
- bars = ax.bar(models, speed_values, color="#55A868")
263
- ax.set_ylabel("Sentences / second")
264
- ax.set_title("Encoding Speed")
265
- for bar, v in zip(bars, speed_values):
266
- if v > 0:
267
- ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
268
- str(v), ha="center", va="bottom", fontsize=9)
269
- plt.xticks(rotation=30, ha="right")
270
- plt.tight_layout()
271
- st.pyplot(fig)
272
- plt.close(fig)
273
-
274
- # Memory chart
275
  mem_values = [r.get("memory_mb", 0) for r in results]
276
- if any(v > 0 for v in mem_values):
277
- fig, ax = plt.subplots(figsize=(max(6, len(models) * 1.5), 4))
278
- bars = ax.bar(models, mem_values, color="#C44E52")
279
- ax.set_ylabel("Peak Memory (MB)")
280
- ax.set_title("Memory Usage")
281
- for bar, v in zip(bars, mem_values):
282
- if v > 0:
283
- ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
284
- str(v), ha="center", va="bottom", fontsize=9)
285
- plt.xticks(rotation=30, ha="right")
286
- plt.tight_layout()
287
- st.pyplot(fig)
288
- plt.close(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  from wrapper import load_model
17
 
18
  # ---------------------------------------------------------------------------
19
+ # Page config & custom CSS
20
  # ---------------------------------------------------------------------------
21
  st.set_page_config(
22
  page_title="Embedding Bench",
 
24
  layout="wide",
25
  )
26
 
27
+ st.markdown("""
28
+ <style>
29
+ /* Tighter top padding */
30
+ .block-container { padding-top: 1.5rem; padding-bottom: 1rem; }
31
+
32
+ /* Metric cards */
33
+ .metric-card {
34
+ background: linear-gradient(135deg, #1a1d23 0%, #22262e 100%);
35
+ border: 1px solid #333;
36
+ border-radius: 10px;
37
+ padding: 14px 18px;
38
+ text-align: center;
39
+ }
40
+ .metric-card .label {
41
+ font-size: 0.72rem;
42
+ color: #888;
43
+ text-transform: uppercase;
44
+ letter-spacing: 0.05em;
45
+ margin-bottom: 4px;
46
+ }
47
+ .metric-card .value {
48
+ font-size: 1.5rem;
49
+ font-weight: 700;
50
+ color: #fafafa;
51
+ }
52
+ .metric-card .sub {
53
+ font-size: 0.7rem;
54
+ color: #666;
55
+ margin-top: 2px;
56
+ }
57
+ .metric-card.best .value { color: #55A868; }
58
+ .metric-card.worst .value { color: #C44E52; }
59
+
60
+ /* Section divider */
61
+ .section-divider {
62
+ border: none;
63
+ border-top: 1px solid #2a2d35;
64
+ margin: 1.2rem 0;
65
+ }
66
+
67
+ /* Footer */
68
+ .footer {
69
+ text-align: center;
70
+ color: #555;
71
+ font-size: 0.75rem;
72
+ padding: 1.5rem 0 0.5rem;
73
+ border-top: 1px solid #222;
74
+ margin-top: 2rem;
75
+ }
76
+ .footer a { color: #4C72B0; text-decoration: none; }
77
+ </style>
78
+ """, unsafe_allow_html=True)
79
+
80
+ # ---------------------------------------------------------------------------
81
+ # Header
82
+ # ---------------------------------------------------------------------------
83
+ col_title, col_badge = st.columns([5, 1])
84
+ with col_title:
85
+ st.markdown("# πŸ“ Embedding Bench")
86
+ st.markdown(
87
+ "<span style='color:#888; font-size:0.95rem;'>"
88
+ "Compare text embedding models on quality, speed &amp; memory.</span>",
89
+ unsafe_allow_html=True,
90
+ )
91
+ with col_badge:
92
+ st.markdown(
93
+ "<div style='text-align:right; padding-top:18px;'>"
94
+ "<a href='https://github.com/amryassin/embedding-bench' target='_blank'>"
95
+ "<img src='https://img.shields.io/badge/GitHub-repo-blue?logo=github' /></a></div>",
96
+ unsafe_allow_html=True,
97
+ )
98
+
99
+ st.markdown("<hr class='section-divider'>", unsafe_allow_html=True)
100
 
101
  # ---------------------------------------------------------------------------
102
  # Sidebar — configuration
103
  # ---------------------------------------------------------------------------
104
+ st.sidebar.markdown("### ⚙️ Configuration")
105
+
106
+ st.sidebar.markdown("**Models**")
107
  available_models = list(REGISTRY.keys())
108
  selected_models = st.sidebar.multiselect(
109
  "Select models",
110
  available_models,
111
  default=["mpnet", "bge-small"] if len(available_models) >= 2 else available_models[:1],
112
+ label_visibility="collapsed",
113
  )
114
 
115
+ st.sidebar.markdown("**Datasets**")
116
  available_datasets = list(DATASET_PRESETS.keys())
117
  selected_datasets = st.sidebar.multiselect(
118
  "Select dataset presets",
119
  available_datasets,
120
  default=["sts"],
121
+ label_visibility="collapsed",
122
  )
123
 
124
  max_pairs = st.sidebar.number_input(
 
130
  help="Limits the number of pairs evaluated. Keep low for large datasets.",
131
  )
132
 
133
+ st.sidebar.markdown("---")
134
+ st.sidebar.markdown("**Speed & Memory**")
135
+ run_speed = st.sidebar.checkbox("Speed benchmark", value=False)
136
+ run_memory = st.sidebar.checkbox("Memory benchmark", value=False)
137
 
138
  corpus_size = 500
139
  num_runs = 3
 
144
  if run_speed:
145
  num_runs = st.sidebar.number_input("Speed runs", 1, 10, 3)
146
 
147
+ st.sidebar.markdown("---")
148
+
149
  # ---------------------------------------------------------------------------
150
  # Helpers
151
  # ---------------------------------------------------------------------------
 
185
  return buf.getvalue()
186
 
187
 
188
+ def render_metric_card(label: str, value: str, sub: str = "", css_class: str = "") -> str:
189
+ cls = f"metric-card {css_class}".strip()
190
+ sub_html = f"<div class='sub'>{sub}</div>" if sub else ""
191
+ return (
192
+ f"<div class='{cls}'>"
193
+ f"<div class='label'>{label}</div>"
194
+ f"<div class='value'>{value}</div>"
195
+ f"{sub_html}"
196
+ f"</div>"
197
+ )
198
+
199
+
200
+ # ---------------------------------------------------------------------------
201
+ # Chart style helper
202
+ # ---------------------------------------------------------------------------
203
+ CHART_BG = "#0E1117"
204
+ CHART_TEXT = "#CCCCCC"
205
+
206
+ def style_chart(fig, ax):
207
+ """Apply dark theme to a matplotlib chart."""
208
+ fig.patch.set_facecolor(CHART_BG)
209
+ ax.set_facecolor(CHART_BG)
210
+ ax.spines["top"].set_visible(False)
211
+ ax.spines["right"].set_visible(False)
212
+ ax.spines["left"].set_color("#444")
213
+ ax.spines["bottom"].set_color("#444")
214
+ ax.tick_params(colors=CHART_TEXT, labelsize=7)
215
+ ax.yaxis.label.set_color(CHART_TEXT)
216
+ ax.xaxis.label.set_color(CHART_TEXT)
217
+ ax.title.set_color("#FAFAFA")
218
+
219
+
220
  # ---------------------------------------------------------------------------
221
  # Run benchmark
222
  # ---------------------------------------------------------------------------
 
241
  cfg = REGISTRY[model_key]
242
  result: dict = {"name": cfg.name, "is_baseline": cfg.is_baseline}
243
 
 
244
  model = get_model(model_key)
245
  quality_results = {}
246
  for ds_cfg in ds_configs:
 
248
  step += 1
249
  progress.progress(
250
  step / total_steps,
251
+ text=f"Evaluating **{cfg.name}** on *{ds_key}*...",
252
  )
253
  quality_results[ds_key] = evaluate_quality(model, ds_cfg, max_pairs=max_pairs)
254
  result["quality"] = quality_results
255
 
 
256
  if run_speed:
257
  step += 1
258
+ progress.progress(step / total_steps, text=f"Speed benchmark: **{cfg.name}**...")
259
  corpus = build_corpus(corpus_size, ds_configs[0])
260
  result["speed"] = evaluate_speed(model, corpus, num_runs=num_runs, batch_size=batch_size)
261
 
 
262
  if run_memory:
263
  step += 1
264
+ progress.progress(step / total_steps, text=f"Memory benchmark: **{cfg.name}**...")
265
  from evals.memory import evaluate_memory
266
  corpus = build_corpus(corpus_size, ds_configs[0])
267
  result["memory_mb"] = evaluate_memory(
 
274
  time.sleep(0.3)
275
  progress.empty()
276
 
 
277
  st.session_state["results"] = results
278
  st.session_state["selected_datasets"] = selected_datasets
279
 
 
281
  # Display results
282
  # ---------------------------------------------------------------------------
283
  if "results" not in st.session_state:
284
+ st.markdown(
285
+ "<div style='text-align:center; padding:3rem 0; color:#666;'>"
286
+ "<p style='font-size:2.5rem; margin-bottom:0.5rem;'>πŸ“</p>"
287
+ "<p style='font-size:1.1rem;'>Configure models &amp; datasets in the sidebar,<br>"
288
+ "then hit <b>Run Benchmark</b>.</p></div>",
289
+ unsafe_allow_html=True,
290
+ )
291
  st.stop()
292
 
293
  results = st.session_state["results"]
294
+ selected_datasets_display = st.session_state["selected_datasets"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
+ # ---------------------------------------------------------------------------
297
+ # Highlight cards
298
+ # ---------------------------------------------------------------------------
299
  ds_keys: list[str] = []
300
  for r in results:
301
  q = r.get("quality")
 
303
  ds_keys = list(q.keys())
304
  break
305
 
306
+ # Build a quick summary: best model per first dataset
307
+ if ds_keys:
308
+ first_ds = ds_keys[0]
309
+ first_metrics_sample = results[0].get("quality", {}).get(first_ds, {})
310
+ primary_metric = "spearman" if "spearman" in first_metrics_sample else "mrr"
311
+ primary_label = "Spearman" if primary_metric == "spearman" else "MRR"
312
+
313
+ scores = [
314
+ (r["name"], r.get("quality", {}).get(first_ds, {}).get(primary_metric, 0))
315
+ for r in results
316
+ ]
317
+ best = max(scores, key=lambda x: x[1])
318
+
319
+ speed_scores = [
320
+ (r["name"], r.get("speed", {}).get("sentences_per_second", 0))
321
+ for r in results
322
+ ]
323
+ fastest = max(speed_scores, key=lambda x: x[1]) if any(s[1] > 0 for s in speed_scores) else None
324
+
325
+ mem_scores = [
326
+ (r["name"], r.get("memory_mb", 0))
327
+ for r in results
328
+ ]
329
+ lightest = min((m for m in mem_scores if m[1] > 0), key=lambda x: x[1], default=None)
330
+
331
+ card_cols = st.columns(3)
332
+ with card_cols[0]:
333
+ st.markdown(render_metric_card(
334
+ f"Best {primary_label} ({first_ds})",
335
+ f"{best[1]:.4f}",
336
+ best[0],
337
+ "best",
338
+ ), unsafe_allow_html=True)
339
+ with card_cols[1]:
340
+ if fastest and fastest[1] > 0:
341
+ st.markdown(render_metric_card(
342
+ "Fastest",
343
+ f"{fastest[1]} sent/s",
344
+ fastest[0],
345
+ "best",
346
+ ), unsafe_allow_html=True)
347
+ else:
348
+ st.markdown(render_metric_card("Fastest", "—", "speed not measured"), unsafe_allow_html=True)
349
+ with card_cols[2]:
350
+ if lightest:
351
+ st.markdown(render_metric_card(
352
+ "Lightest",
353
+ f"{lightest[1]} MB",
354
+ lightest[0],
355
+ "best",
356
+ ), unsafe_allow_html=True)
357
+ else:
358
+ st.markdown(render_metric_card("Lightest", "—", "memory not measured"), unsafe_allow_html=True)
359
+
360
+ st.markdown("")
361
+
362
+ # ---------------------------------------------------------------------------
363
+ # Results table
364
+ # ---------------------------------------------------------------------------
365
+ st.markdown("#### 📊 Detailed Results")
366
+ flat_rows = [flatten_result(r) for r in results]
367
+ st.dataframe(flat_rows, use_container_width=True, hide_index=True)
368
+
369
+ col_dl, _ = st.columns([1, 4])
370
+ with col_dl:
371
+ csv_data = results_to_csv(results)
372
+ st.download_button(
373
+ "📥 Download CSV",
374
+ data=csv_data,
375
+ file_name="embedding_bench_results.csv",
376
+ mime="text/csv",
377
+ use_container_width=True,
378
+ )
379
+
380
+ st.markdown("<hr class='section-divider'>", unsafe_allow_html=True)
381
+
382
+ # ---------------------------------------------------------------------------
383
+ # Charts
384
+ # ---------------------------------------------------------------------------
385
+ st.markdown("#### 📈 Charts")
386
+ models = [r["name"] for r in results]
387
+
388
  for ds_key in ds_keys:
389
  first_metrics = None
390
  for r in results:
 
397
 
398
  if "spearman" in first_metrics:
399
  values = [r.get("quality", {}).get(ds_key, {}).get("spearman", 0) for r in results]
400
+ fig, ax = plt.subplots(figsize=(4, 2.4))
401
+ style_chart(fig, ax)
402
+ bars = ax.bar(models, values, color="#4C72B0", edgecolor="#5a82c0", linewidth=0.5)
403
+ ax.set_ylabel("Spearman", fontsize=8)
404
+ ax.set_title(f"Quality — {ds_key}", fontsize=9, pad=8)
405
+ ax.set_ylim(0, 1.08)
406
  for bar, v in zip(bars, values):
407
  ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
408
+ f"{v:.4f}", ha="center", va="bottom", fontsize=7, color=CHART_TEXT)
409
  plt.xticks(rotation=30, ha="right")
410
  plt.tight_layout()
411
+ st.pyplot(fig, use_container_width=False)
412
  plt.close(fig)
413
  else:
414
  metric_names = ["mrr", "recall@1", "recall@5", "recall@10"]
 
416
  width = 0.18
417
  colors = ["#4C72B0", "#55A868", "#C44E52", "#8172B2"]
418
 
419
+ fig, ax = plt.subplots(figsize=(max(4, len(models) * 1.4), 2.6))
420
+ style_chart(fig, ax)
421
  for i, (metric, color) in enumerate(zip(metric_names, colors)):
422
  values = [r.get("quality", {}).get(ds_key, {}).get(metric, 0) for r in results]
423
  offset = (i - 1.5) * width
424
+ bars = ax.bar(x + offset, values, width, label=metric, color=color,
425
+ edgecolor=color, linewidth=0.3, alpha=0.9)
426
  for bar, v in zip(bars, values):
427
  ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
428
+ f"{v:.2f}", ha="center", va="bottom", fontsize=6, color=CHART_TEXT)
429
+ ax.set_ylabel("Score", fontsize=8)
430
+ ax.set_title(f"Retrieval Quality — {ds_key}", fontsize=9, pad=8)
431
  ax.set_ylim(0, 1.15)
432
  ax.set_xticks(x)
433
+ ax.set_xticklabels(models, rotation=30, ha="right", fontsize=7)
434
+ ax.legend(fontsize=6, ncol=4, loc="upper right",
435
+ facecolor=CHART_BG, edgecolor="#444", labelcolor=CHART_TEXT)
436
  plt.tight_layout()
437
+ st.pyplot(fig, use_container_width=False)
438
  plt.close(fig)
439
 
440
+ # Speed & Memory side by side
441
  speed_values = [r.get("speed", {}).get("sentences_per_second", 0) for r in results]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442
  mem_values = [r.get("memory_mb", 0) for r in results]
443
+ has_speed = any(v > 0 for v in speed_values)
444
+ has_memory = any(v > 0 for v in mem_values)
445
+
446
+ if has_speed or has_memory:
447
+ cols = st.columns(2 if has_speed and has_memory else 1)
448
+
449
+ if has_speed:
450
+ with cols[0]:
451
+ fig, ax = plt.subplots(figsize=(3.5, 2.4))
452
+ style_chart(fig, ax)
453
+ bars = ax.bar(models, speed_values, color="#55A868", edgecolor="#65b878", linewidth=0.5)
454
+ ax.set_ylabel("Sent / s", fontsize=8)
455
+ ax.set_title("Encoding Speed", fontsize=9, pad=8)
456
+ for bar, v in zip(bars, speed_values):
457
+ if v > 0:
458
+ ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
459
+ str(v), ha="center", va="bottom", fontsize=7, color=CHART_TEXT)
460
+ plt.xticks(rotation=30, ha="right")
461
+ plt.tight_layout()
462
+ st.pyplot(fig, use_container_width=False)
463
+ plt.close(fig)
464
+
465
+ if has_memory:
466
+ col_idx = 1 if has_speed else 0
467
+ with cols[col_idx]:
468
+ fig, ax = plt.subplots(figsize=(3.5, 2.4))
469
+ style_chart(fig, ax)
470
+ bars = ax.bar(models, mem_values, color="#C44E52", edgecolor="#d45e62", linewidth=0.5)
471
+ ax.set_ylabel("MB", fontsize=8)
472
+ ax.set_title("Memory Usage", fontsize=9, pad=8)
473
+ for bar, v in zip(bars, mem_values):
474
+ if v > 0:
475
+ ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
476
+ str(v), ha="center", va="bottom", fontsize=7, color=CHART_TEXT)
477
+ plt.xticks(rotation=30, ha="right")
478
+ plt.tight_layout()
479
+ st.pyplot(fig, use_container_width=False)
480
+ plt.close(fig)
481
+
482
+ # ---------------------------------------------------------------------------
483
+ # Footer
484
+ # ---------------------------------------------------------------------------
485
+ st.markdown(
486
+ "<div class='footer'>"
487
+ "Built with <a href='https://streamlit.io'>Streamlit</a> · "
488
+ "Models via <a href='https://huggingface.co'>HuggingFace</a> · "
489
+ "<a href='https://github.com/amryassin/embedding-bench'>Source on GitHub</a>"
490
+ "</div>",
491
+ unsafe_allow_html=True,
492
+ )