Lev Israel committed
Commit 018c4c5 · 0 Parent(s)

Initial Commit
.gitignore ADDED
@@ -0,0 +1,26 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ venv/
+ .venv/
+ env/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Secrets (never commit these!)
+ .env
+ *.env
+
+ # Leaderboard data (regenerated on Space)
+ benchmark_data/leaderboard.json
README.md ADDED
@@ -0,0 +1,58 @@
+ # Rabbinic Hebrew/Aramaic Embedding Evaluation
+
+ A Hugging Face Space for evaluating embedding models on Rabbinic Hebrew and Aramaic texts using cross-lingual retrieval benchmarks.
+
+ ## Overview
+
+ This tool helps identify which embedding models best capture the semantics of Rabbinic Hebrew and Aramaic by measuring how well they align source texts with their English translations. Models that excel at this task are likely to produce high-quality embeddings for untranslated texts.
+
+ ## Evaluation Approach
+
+ Given a Hebrew/Aramaic text, the benchmark tests whether the embedding model can find its correct English translation from a pool of candidates. This cross-lingual retrieval task measures semantic alignment across languages.
+
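The retrieval step can be sketched in plain NumPy: embed queries and candidates, take cosine similarities, and derive the rank of the true translation. The embeddings below are synthetic stand-ins, not real model output or benchmark data.

```python
import numpy as np

# Toy sketch of the retrieval step (synthetic embeddings, not real model output).
# Row i of `q` stands for a Hebrew/Aramaic query; row i of `c` for its true
# English translation. Rows are unit-normalized, so q @ c.T is cosine similarity.
rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))
q /= np.linalg.norm(q, axis=1, keepdims=True)
c = q + 0.1 * rng.normal(size=q.shape)  # translation ~ source + small noise
c /= np.linalg.norm(c, axis=1, keepdims=True)

sim = q @ c.T  # (4 queries) x (4 candidates)
# 1-based rank of the correct candidate (the diagonal entry) for each query
ranks = 1 + (sim > np.diag(sim)[:, None]).sum(axis=1)
print(ranks)
```

All retrieval metrics below are functions of these per-query ranks.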
+ ### Metrics
+
+ | Metric | Description |
+ |--------|-------------|
+ | **Recall@1** | % of queries where the correct translation is the top result |
+ | **Recall@5** | % where the correct translation is in the top 5 results |
+ | **Recall@10** | % where the correct translation is in the top 10 results |
+ | **MRR** | Mean Reciprocal Rank (average of 1/rank of the correct answer) |
+
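As a concrete illustration (a minimal sketch, not the app's own evaluation code), both metrics reduce to simple functions of the 1-based rank of the correct translation for each query:

```python
def recall_at_k(ranks: list[int], k: int) -> float:
    """Fraction of queries whose correct translation ranks in the top k."""
    return sum(1 for r in ranks if r <= k) / len(ranks)

def mean_reciprocal_rank(ranks: list[int]) -> float:
    """Average of 1/rank of the correct translation."""
    return sum(1.0 / r for r in ranks) / len(ranks)

# Example: five queries whose correct translations ranked 1, 3, 2, 10, 1
ranks = [1, 3, 2, 10, 1]
print(recall_at_k(ranks, 1))        # 0.4
print(recall_at_k(ranks, 5))        # 0.8
print(mean_reciprocal_rank(ranks))  # ~0.587
```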
+ ## Corpus
+
+ The benchmark includes diverse texts from Sefaria with English translations:
+
+ - **Talmud**: Bavli and Yerushalmi (Aramaic + Hebrew)
+ - **Mishnah**: All tractates (Rabbinic Hebrew)
+ - **Midrash**: Midrash Rabbah (Hebrew/Aramaic)
+ - **Tanakh Commentary**: Rashi and Ramban on Tanakh (Hebrew)
+ - **Hasidic/Kabbalistic**: Likutei Moharan, Tomer Devorah (Hebrew)
+ - **Halacha**: Sefer HaHinuch, Intro to Shev Shmateta (Hebrew)
+
+ ## Usage
+
+ 1. Select a model from the curated list or enter any Hugging Face model ID
+ 2. Click "Run Evaluation"
+ 3. View results and compare with the leaderboard
+
+ ## Models
+
+ ### Curated Models
+ - `intfloat/multilingual-e5-large`
+ - `sentence-transformers/paraphrase-multilingual-mpnet-base-v2`
+ - `BAAI/bge-m3`
+
+ You can also evaluate any sentence-transformers-compatible model from the Hugging Face Hub.
+
+ ## Local Development
+
+ ```bash
+ pip install -r requirements.txt
+ python app.py
+ ```
+
+ ## License
+
+ MIT
+
app.py ADDED
@@ -0,0 +1,498 @@
+ """
+ Gradio interface for Rabbinic Hebrew/Aramaic Embedding Evaluation.
+
+ A Hugging Face Space for evaluating embedding models on cross-lingual
+ retrieval between Hebrew/Aramaic source texts and English translations.
+ """
+
+ import json
+ import os
+ from datetime import datetime
+ from pathlib import Path
+
+ import gradio as gr
+ import pandas as pd
+ import plotly.graph_objects as go
+
+ from data_loader import load_benchmark_dataset, get_benchmark_stats
+ from models import (
+     CURATED_MODELS,
+     API_MODELS,
+     ALL_MODELS,
+     get_curated_model_choices,
+     get_api_model_choices,
+     get_all_model_choices,
+     load_model,
+     validate_model_id,
+     is_api_model,
+     requires_api_key,
+     api_key_optional,
+     get_api_key_type,
+     get_api_key_env_var,
+ )
+ from evaluation import (
+     EvaluationResults,
+     evaluate_model,
+     evaluate_model_streaming,
+     compute_similarity_matrix,
+     get_rank_distribution,
+ )
+
+ # Paths
+ BENCHMARK_PATH = "benchmark_data/benchmark.json"
+ LEADERBOARD_PATH = "benchmark_data/leaderboard.json"
+
+ # Global state
+ _benchmark_data = None
+ _leaderboard = []
+
+
+ def load_benchmark():
+     """Load benchmark data, with fallback to sample data."""
+     global _benchmark_data
+
+     if _benchmark_data is not None:
+         return _benchmark_data
+
+     try:
+         _benchmark_data = load_benchmark_dataset(BENCHMARK_PATH)
+         print(f"Loaded {len(_benchmark_data)} benchmark pairs")
+     except FileNotFoundError:
+         print("Benchmark not found, using sample data")
+         # Create minimal sample data for testing
+         _benchmark_data = [
+             {
+                 "ref": "Sample.1",
+                 "he": "בראשית ברא אלהים את השמים ואת הארץ",
+                 "en": "In the beginning God created the heaven and the earth",
+                 "category": "Sample",
+             },
+             {
+                 "ref": "Sample.2",
+                 "he": "והארץ היתה תהו ובהו וחשך על פני תהום",
+                 "en": "And the earth was without form, and void; and darkness was upon the face of the deep",
+                 "category": "Sample",
+             },
+         ]
+
+     return _benchmark_data
+
+
+ def load_leaderboard():
+     """Load saved leaderboard results."""
+     global _leaderboard
+
+     try:
+         with open(LEADERBOARD_PATH, "r") as f:
+             _leaderboard = json.load(f)
+     except FileNotFoundError:
+         _leaderboard = []
+
+     return _leaderboard
+
+
+ def save_leaderboard():
+     """Save leaderboard to file."""
+     global _leaderboard
+
+     Path(LEADERBOARD_PATH).parent.mkdir(parents=True, exist_ok=True)
+     with open(LEADERBOARD_PATH, "w") as f:
+         json.dump(_leaderboard, f, indent=2)
+
+
+ def add_to_leaderboard(results: EvaluationResults):
+     """Add evaluation results to leaderboard."""
+     global _leaderboard
+
+     entry = results.to_dict()
+     entry["timestamp"] = datetime.now().isoformat()
+
+     # Remove existing entry for same model
+     _leaderboard = [e for e in _leaderboard if e["model_id"] != results.model_id]
+     _leaderboard.append(entry)
+
+     # Sort by MRR descending
+     _leaderboard.sort(key=lambda x: x["mrr"], reverse=True)
+
+     save_leaderboard()
+
+
+ def format_leaderboard_df():
+     """Format leaderboard as pandas DataFrame for display."""
+     load_leaderboard()
+
+     if not _leaderboard:
+         return pd.DataFrame(columns=[
+             "#", "Model", "MRR", "R@1", "R@5", "R@10",
+             "Bitext", "TrueSim", "RandSim", "N"
+         ])
+
+     rows = []
+     for i, entry in enumerate(_leaderboard, 1):
+         rows.append({
+             "#": i,
+             "Model": entry.get("model_name", entry["model_id"]),
+             "MRR": f"{entry['mrr']:.3f}",
+             "R@1": f"{entry['recall_at_1']:.1%}",
+             "R@5": f"{entry['recall_at_5']:.1%}",
+             "R@10": f"{entry['recall_at_10']:.1%}",
+             "Bitext": f"{entry['bitext_accuracy']:.1%}",
+             "TrueSim": f"{entry['avg_true_pair_similarity']:.3f}",
+             "RandSim": f"{entry['avg_random_pair_similarity']:.3f}",
+             "N": entry["num_pairs"],
+         })
+
+     return pd.DataFrame(rows)
+
+
+ def run_evaluation(
+     model_choice: str,
+     custom_model_id: str,
+     api_key: str,
+     max_pairs: int,
+ ):
+     """
+     Run evaluation for the selected model (generator for streaming status updates).
+
+     Args:
+         model_choice: Selected curated model or "custom"
+         custom_model_id: Custom model ID if selected
+         api_key: API key for API-based models
+         max_pairs: Maximum pairs to evaluate
+
+     Yields:
+         Tuples of (status, results, leaderboard)
+     """
+     # Helper to yield status updates
+     def status_update(msg):
+         return (msg, gr.update(), gr.update())
+
+     # Determine which model to use
+     if model_choice == "custom":
+         model_id = custom_model_id.strip()
+         is_valid, error = validate_model_id(model_id)
+         if not is_valid:
+             yield (
+                 f"❌ {error}",
+                 f"❌ Invalid model ID: {error}",
+                 format_leaderboard_df(),
+             )
+             return
+     else:
+         model_id = model_choice
+
+     # Check if an API key is required but not provided
+     if requires_api_key(model_id):
+         api_key = api_key.strip() if api_key else ""
+         env_var = get_api_key_env_var(model_id)
+         key_type = get_api_key_type(model_id)
+
+         # Skip the API key check for models that support Application Default Credentials
+         if not api_key and not os.environ.get(env_var) and not api_key_optional(model_id):
+             yield (
+                 "❌ API key required",
+                 f"❌ API key required for {model_id}. Please enter your {key_type.upper()} API key or set the {env_var} environment variable.",
+                 format_leaderboard_df(),
+             )
+             return
+
+     yield status_update("⏳ Loading benchmark data...")
+     benchmark = load_benchmark()
+
+     if max_pairs and max_pairs < len(benchmark):
+         benchmark = benchmark[:max_pairs]
+
+     yield status_update(f"⏳ Loading model: {model_id}...")
+
+     try:
+         # Pass API key for API-based models
+         model = load_model(model_id, api_key=api_key if api_key else None)
+     except Exception as e:
+         yield (
+             "❌ Model load failed",
+             f"❌ Failed to load model: {str(e)}",
+             format_leaderboard_df(),
+         )
+         return
+
+     # Stream progress updates during evaluation
+     try:
+         results = None
+         for item in evaluate_model_streaming(model, benchmark, batch_size=32):
+             if isinstance(item, str):
+                 # Progress update
+                 yield status_update(item)
+             else:
+                 # Final results
+                 results = item
+     except Exception as e:
+         yield (
+             "❌ Evaluation failed",
+             f"❌ Evaluation failed: {str(e)}",
+             format_leaderboard_df(),
+         )
+         return
+
+     yield status_update("⏳ Saving results...")
+     add_to_leaderboard(results)
+
+     # Format results summary
+     summary = f"""## Results for {results.model_name}
+
+ | Metric | Value |
+ |--------|-------|
+ | **MRR** | {results.mrr:.4f} |
+ | **Recall@1** | {results.recall_at_1:.1%} |
+ | **Recall@5** | {results.recall_at_5:.1%} |
+ | **Recall@10** | {results.recall_at_10:.1%} |
+ | **Bitext Accuracy** | {results.bitext_accuracy:.1%} |
+ | **Avg True Pair Sim** | {results.avg_true_pair_similarity:.4f} |
+ | **Avg Random Pair Sim** | {results.avg_random_pair_similarity:.4f} |
+ | **Pairs Evaluated** | {results.num_pairs:,} |
+ """
+
+     # Final yield with the full results
+     yield (
+         "✅ Complete!",
+         summary,
+         format_leaderboard_df(),
+     )
+
+
+ def create_leaderboard_comparison():
+     """Create comparison chart of all models on leaderboard."""
+     load_leaderboard()
+
+     if len(_leaderboard) < 2:
+         return None
+
+     models = [e.get("model_name", e["model_id"]) for e in _leaderboard]
+     mrr = [e["mrr"] for e in _leaderboard]
+     r1 = [e["recall_at_1"] for e in _leaderboard]
+     r5 = [e["recall_at_5"] for e in _leaderboard]
+     r10 = [e["recall_at_10"] for e in _leaderboard]
+     bitext = [e["bitext_accuracy"] for e in _leaderboard]
+
+     fig = go.Figure()
+
+     fig.add_trace(go.Bar(name="MRR", x=models, y=mrr, marker_color="#2E86AB"))
+     fig.add_trace(go.Bar(name="R@1", x=models, y=r1, marker_color="#A23B72"))
+     fig.add_trace(go.Bar(name="R@5", x=models, y=r5, marker_color="#F18F01"))
+     fig.add_trace(go.Bar(name="R@10", x=models, y=r10, marker_color="#C73E1D"))
+     fig.add_trace(go.Bar(name="Bitext Acc", x=models, y=bitext, marker_color="#6B5B95"))
+
+     fig.update_layout(
+         title="Model Comparison",
+         yaxis_title="Score",
+         yaxis_range=[0, 1],
+         barmode="group",
+         template="plotly_white",
+         height=400,
+     )
+
+     return fig
+
+
+ def update_model_inputs_visibility(choice):
+     """Show/hide custom model input and API key based on selection."""
+     show_custom = (choice == "custom")
+     show_api_key = requires_api_key(choice) if choice != "custom" else False
+
+     # Update API key label based on model type
+     if show_api_key:
+         key_type = get_api_key_type(choice)
+         env_var = get_api_key_env_var(choice)
+         is_optional = api_key_optional(choice)
+
+         if key_type == "voyage":
+             label = "Voyage AI API Key"
+             placeholder = f"Enter your Voyage AI API key (or set {env_var} env var)"
+         elif key_type == "gemini":
+             label = "Gemini API Key (optional if using gcloud)"
+             placeholder = f"Leave blank if using gcloud ADC, or enter API key / set {env_var}"
+         else:
+             label = "OpenAI API Key"
+             placeholder = f"Enter your OpenAI API key (or set {env_var} env var)"
+         return (
+             gr.update(visible=show_custom),
+             gr.update(visible=show_api_key, label=label, placeholder=placeholder),
+         )
+
+     return (
+         gr.update(visible=show_custom),
+         gr.update(visible=show_api_key),
+     )
+
+
+ # Build the Gradio interface
+ def create_app():
+     """Create and return the Gradio app."""
+
+     # Get all model choices - local models first, then API models
+     model_choices = []
+
+     # Local models
+     for model_id, info in CURATED_MODELS.items():
+         model_choices.append((f"🖥️ {info['name']}", model_id))
+
+     # API models
+     for model_id, info in API_MODELS.items():
+         model_choices.append((f"🌐 {info['name']}", model_id))
+
+     # Custom option
+     model_choices.append(("⚙️ Custom Model (enter ID below)", "custom"))
+
+     # Load initial data
+     load_benchmark()
+     load_leaderboard()
+     benchmark_stats = get_benchmark_stats(_benchmark_data) if _benchmark_data else {}
+
+     with gr.Blocks(
+         title="Rabbinic Embedding Benchmark",
+         theme=gr.themes.Soft(
+             primary_hue="blue",
+             secondary_hue="orange",
+             font=gr.themes.GoogleFont("Source Sans Pro"),
+         ),
+         css="""
+         .main-header {
+             text-align: center;
+             margin-bottom: 1rem;
+         }
+         .stats-box {
+             background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+             color: white;
+             padding: 1rem;
+             border-radius: 8px;
+             margin: 0.5rem 0;
+         }
+         """,
+     ) as app:
+
+         gr.Markdown(
+             """
+             # 📚 Rabbinic Hebrew/Aramaic Embedding Benchmark
+
+             Evaluate embedding models on cross-lingual retrieval between Hebrew/Aramaic
+             source texts and their English translations from Sefaria.
+
+             **How it works:** Given a Hebrew/Aramaic text, can the model find its correct
+             English translation from a pool of candidates? Models that excel at this task
+             produce high-quality embeddings for Rabbinic literature.
+             """,
+             elem_classes=["main-header"],
+         )
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 gr.Markdown(f"""
+                 ### 📊 Benchmark Stats
+                 - **Total Pairs:** {benchmark_stats.get('total_pairs', 0):,}
+                 - **Categories:** {len(benchmark_stats.get('categories', {}))}
+                 - **Avg Hebrew Length:** {benchmark_stats.get('avg_he_length', 0):.0f} chars
+                 """)
+
+             with gr.Column(scale=1):
+                 gr.Markdown("""
+                 ### 📏 Metrics
+                 - **MRR:** Mean Reciprocal Rank
+                 - **R@k:** Recall at k (correct in top k)
+                 - **Bitext Acc:** True vs random pair classification
+                 """)
+
+         gr.Markdown("---")
+
+         with gr.Tabs():
+             with gr.TabItem("🔬 Evaluate Model"):
+                 with gr.Row():
+                     with gr.Column(scale=2):
+                         model_dropdown = gr.Dropdown(
+                             choices=model_choices,
+                             value=model_choices[0][1],
+                             label="Select Model",
+                             info="Choose a curated model or enter a custom Hugging Face model ID",
+                         )
+
+                         custom_model_input = gr.Textbox(
+                             label="Custom Model ID",
+                             placeholder="e.g., organization/model-name",
+                             visible=False,
+                         )
+
+                         api_key_input = gr.Textbox(
+                             label="API Key",
+                             placeholder="Enter your API key (or set appropriate env var)",
+                             type="password",
+                             visible=False,
+                             info="Required for API-based models (OpenAI, Voyage AI). Your key is not stored.",
+                         )
+
+                         total_pairs = benchmark_stats.get('total_pairs', 1000)
+                         max_pairs_slider = gr.Slider(
+                             minimum=100,
+                             maximum=total_pairs,
+                             value=total_pairs,
+                             step=100,
+                             label="Max Pairs to Evaluate",
+                             info="Use fewer pairs for faster evaluation",
+                         )
+
+                     with gr.Column(scale=3):
+                         evaluate_btn = gr.Button(
+                             "🚀 Run Evaluation",
+                             variant="primary",
+                             size="lg",
+                         )
+
+                         status_text = gr.Markdown("")
+
+                         results_markdown = gr.Markdown("")
+
+             with gr.TabItem("🏆 Leaderboard"):
+                 leaderboard_table = gr.Dataframe(
+                     value=format_leaderboard_df(),
+                     label="Model Rankings",
+                     interactive=False,
+                 )
+
+                 refresh_btn = gr.Button("🔄 Refresh Leaderboard")
+
+                 comparison_plot = gr.Plot(label="Model Comparison")
+
+         gr.Markdown("""
+         ---
+         ### About
+
+         This benchmark evaluates embedding models for Rabbinic Hebrew and Aramaic texts using
+         cross-lingual retrieval.
+
+         All texts and translations sourced from [Sefaria](https://www.sefaria.org).
+         """)
+
+         # Event handlers
+         model_dropdown.change(
+             fn=update_model_inputs_visibility,
+             inputs=[model_dropdown],
+             outputs=[custom_model_input, api_key_input],
+         )
+
+         evaluate_btn.click(
+             fn=run_evaluation,
+             inputs=[model_dropdown, custom_model_input, api_key_input, max_pairs_slider],
+             outputs=[status_text, results_markdown, leaderboard_table],
+             show_progress="hidden",
+         )
+
+         refresh_btn.click(
+             fn=lambda: (format_leaderboard_df(), create_leaderboard_comparison()),
+             outputs=[leaderboard_table, comparison_plot],
+         )
+
+     return app
+
+
+ # Main entry point
+ if __name__ == "__main__":
+     app = create_app()
+     app.launch()
+
benchmark-stats.txt ADDED
@@ -0,0 +1,14 @@
+ Total pairs: 3,721
+ Categories:
+ - Halacha: 160
+ - Hasidic/Kabbalistic: 304
+ - Jerusalem Talmud: 520
+ - Midrash Rabbah: 400
+ - Mishnah: 789
+ - Mussar/Ethics: 108
+ - Philosophy: 240
+ - Talmud: 480
+ - Tanakh Commentary: 680
+ - Targum: 40
+ Average Hebrew text length: 650 chars
+ Average English text length: 995 chars
benchmark_data/benchmark.json ADDED
The diff for this file is too large to render. See raw diff
 
build_benchmark.py ADDED
@@ -0,0 +1,190 @@
+ #!/usr/bin/env python3
+ """
+ Script to build the benchmark dataset from the Sefaria API.
+
+ Run this script to fetch and cache parallel Hebrew/Aramaic-English text pairs
+ from Sefaria for use in the embedding evaluation benchmark.
+
+ Usage:
+     python build_benchmark.py [--max-per-text N] [--total N] [--output PATH]
+ """
+
+ import argparse
+ import sys
+ from pathlib import Path
+
+ import requests
+
+ from data_loader import (
+     build_benchmark_dataset,
+     get_benchmark_stats,
+     get_index_from_sefaria,
+     set_sefaria_host,
+     get_sefaria_host,
+     BENCHMARK_TEXTS,
+ )
+
+
+ def get_name_suggestions(title: str, host: str, limit: int = 5) -> list[str]:
+     """Get name suggestions from the Sefaria name API."""
+     try:
+         url = f"{host}/api/name/{title}"
+         response = requests.get(url, params={"limit": limit, "type": "ref"}, timeout=10)
+         if response.status_code == 200:
+             data = response.json()
+             # Return completions that are refs (book titles)
+             completions = data.get("completions", [])
+             return completions[:limit]
+     except requests.RequestException:
+         pass
+     return []
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Build Rabbinic embedding benchmark dataset from Sefaria"
+     )
+     parser.add_argument(
+         "--max-per-text",
+         type=int,
+         default=40,
+         help="Maximum segments per text (default: 40)",
+     )
+     parser.add_argument(
+         "--total",
+         type=int,
+         default=10000,
+         help="Total target segments (default: 10000)",
+     )
+     parser.add_argument(
+         "--output",
+         type=str,
+         default="benchmark_data/benchmark.json",
+         help="Output file path (default: benchmark_data/benchmark.json)",
+     )
+     parser.add_argument(
+         "--dry-run",
+         action="store_true",
+         help="Show what would be fetched without making API calls",
+     )
+     parser.add_argument(
+         "--host",
+         type=str,
+         default=None,
+         help="Sefaria host URL (default: https://www.sefaria.org, or SEFARIA_HOST env var)",
+     )
+     parser.add_argument(
+         "--check-titles",
+         action="store_true",
+         help="Check all text titles against the API to verify they exist",
+     )
+
+     args = parser.parse_args()
+
+     # Configure Sefaria host if specified
+     if args.host:
+         set_sefaria_host(args.host)
+
+     if args.check_titles:
+         print("=" * 60)
+         print("Checking Text Titles Against API")
+         print("=" * 60)
+         host = get_sefaria_host()
+         print(f"\nSefaria host: {host}\n")
+
+         valid = []
+         invalid = []
+         suggestions = {}
+
+         for category_key, category_info in BENCHMARK_TEXTS.items():
+             category_name = category_info["category"]
+             print(f"\n{category_name}:")
+
+             for text in category_info["texts"]:
+                 index = get_index_from_sefaria(text)
+                 if index:
+                     print(f"  ✓ {text}")
+                     valid.append(text)
+                 else:
+                     # Get suggestions from the name API
+                     suggested = get_name_suggestions(text, host)
+                     suggestions[text] = suggested
+                     if suggested:
+                         print(f"  ✗ {text} → Did you mean: {suggested[0]}?")
+                     else:
+                         print(f"  ✗ {text}")
+                     invalid.append(text)
+
+         print("\n" + "=" * 60)
+         print("SUMMARY")
+         print("=" * 60)
+         print(f"\nValid titles: {len(valid)}")
+         print(f"Invalid titles: {len(invalid)}")
+
+         if invalid:
+             print("\nInvalid titles that need fixing:")
+             for title in invalid:
+                 suggested = suggestions.get(title, [])
+                 if suggested:
+                     print(f"  - {title}")
+                     print(f"    Suggestions: {', '.join(suggested[:3])}")
+                 else:
+                     print(f"  - {title} (no suggestions found)")
+         else:
+             print("\nAll titles are valid!")
+
+         return
+
+     if args.dry_run:
+         print("DRY RUN: Would fetch from these texts:\n")
+         print(f"Sefaria host: {get_sefaria_host()}")
+         total_texts = 0
+         for category_key, category_info in BENCHMARK_TEXTS.items():
+             print(f"\n{category_info['category']} ({category_info['language']}):")
+             for text in category_info["texts"]:
+                 print(f"  - {text}")
+                 total_texts += 1
+         print(f"\nTotal texts: {total_texts}")
+         print(f"Target segments per text: {args.max_per_text}")
+         print(f"Total target segments: {args.total}")
+         return
+
+     print("=" * 60)
+     print("Building Rabbinic Embedding Benchmark Dataset")
+     print("=" * 60)
+     print("\nSettings:")
+     print(f"  Sefaria host: {get_sefaria_host()}")
+     print(f"  Max segments per text: {args.max_per_text}")
+     print(f"  Total target: {args.total}")
+     print(f"  Output: {args.output}")
+     print()
+
+     # Ensure the output directory exists
+     Path(args.output).parent.mkdir(parents=True, exist_ok=True)
+
+     # Build the dataset
+     pairs = build_benchmark_dataset(
+         output_path=args.output,
+         segments_per_text=args.max_per_text,
+         total_target=args.total,
+     )
+
+     # Print final statistics
+     stats = get_benchmark_stats(pairs)
+
+     print("\n" + "=" * 60)
+     print("BENCHMARK COMPLETE")
+     print("=" * 60)
+     print("\nFinal Statistics:")
+     print(f"  Total pairs: {stats['total_pairs']:,}")
+     print("  Categories:")
+     for cat, count in sorted(stats["categories"].items()):
+         print(f"    - {cat}: {count:,}")
+     print(f"  Average Hebrew text length: {stats['avg_he_length']:.0f} chars")
+     print(f"  Average English text length: {stats['avg_en_length']:.0f} chars")
+     print(f"\nSaved to: {args.output}")
+
+
+ if __name__ == "__main__":
+     main()
+
check_token_limits.py ADDED
@@ -0,0 +1,306 @@
+ """
+ Check token limits for benchmark data entries.
+
+ This script scans the benchmark dataset and flags entries that exceed
+ the 8192 token limit used by OpenAI embedding models (text-embedding-ada-002,
+ text-embedding-3-small, text-embedding-3-large).
+
+ Uses tiktoken with the cl100k_base encoding, which is the tokenizer used
+ by OpenAI's embedding models.
+ """
+
+ import json
+ import argparse
+ from pathlib import Path
+ from dataclasses import dataclass
+
+ import tiktoken
+
+
+ # OpenAI embedding models use cl100k_base encoding
+ ENCODING_NAME = "cl100k_base"
+ MAX_TOKENS = 8192
+
+
+ @dataclass
+ class TokenOverage:
+     """Represents an entry that exceeds the token limit."""
+     ref: str
+     category: str
+     field: str  # 'he', 'en', or 'combined'
+     token_count: int
+     char_count: int
+     text_preview: str  # First N characters of the text
+
+
+ def count_tokens(text: str, encoding: tiktoken.Encoding) -> int:
+     """Count the number of tokens in a text string."""
+     return len(encoding.encode(text))
+
+
+ def check_entry(
+     entry: dict,
+     encoding: tiktoken.Encoding,
+     max_tokens: int = MAX_TOKENS,
+     preview_length: int = 100
+ ) -> list[TokenOverage]:
+     """
+     Check a single entry for token limit violations.
+
+     Args:
+         entry: Dictionary with 'ref', 'he', 'en', 'category' keys
+         encoding: tiktoken encoding to use
+         max_tokens: Maximum allowed tokens
+         preview_length: Number of characters to include in preview
+
+     Returns:
+         List of TokenOverage objects for any fields exceeding the limit
+     """
+     overages = []
+
+     ref = entry.get("ref", "unknown")
+     category = entry.get("category", "unknown")
+
+     for field in ["he", "en"]:
+         text = entry.get(field, "")
+         if not text:
+             continue
+
+         token_count = count_tokens(text, encoding)
+
+         if token_count > max_tokens:
+             preview = text[:preview_length] + "..." if len(text) > preview_length else text
+             overages.append(TokenOverage(
+                 ref=ref,
+                 category=category,
+                 field=field,
+                 token_count=token_count,
+                 char_count=len(text),
+                 text_preview=preview
+             ))
+
+     return overages
+
+
+ def check_benchmark_data(
+     data_path: str,
+     max_tokens: int = MAX_TOKENS,
+     verbose: bool = False
+ ) -> tuple[list[TokenOverage], dict]:
+     """
+     Check all entries in the benchmark dataset for token limit violations.
+
+     Args:
+         data_path: Path to the benchmark JSON file
+         max_tokens: Maximum allowed tokens (default: 8192)
+         verbose: Print progress information
+
+     Returns:
+         Tuple of (list of overages, statistics dict)
+     """
+     # Load the encoding
+     if verbose:
+         print(f"Loading tokenizer: {ENCODING_NAME}")
+     encoding = tiktoken.get_encoding(ENCODING_NAME)
+
+     # Load the data
+     if verbose:
+         print(f"Loading data from: {data_path}")
+     with open(data_path, "r", encoding="utf-8") as f:
+         data = json.load(f)
+
+     if verbose:
+         print(f"Checking {len(data)} entries for token limit ({max_tokens} tokens)...")
+
+     # Check all entries
+     all_overages = []
+     token_counts_he = []
+     token_counts_en = []
+
+     for i, entry in enumerate(data):
+         if verbose and (i + 1) % 1000 == 0:
+             print(f"  Processed {i + 1}/{len(data)} entries...")
+
+         # Count tokens for statistics
+         he_text = entry.get("he", "")
+         en_text = entry.get("en", "")
+
+         if he_text:
+             token_counts_he.append(count_tokens(he_text, encoding))
+         if en_text:
+             token_counts_en.append(count_tokens(en_text, encoding))
+
+         # Check for overages
+         overages = check_entry(entry, encoding, max_tokens)
+         all_overages.extend(overages)
+
+     # Compute statistics
+     stats = {
+         "total_entries": len(data),
+         "entries_with_overages": len(set(o.ref for o in all_overages)),
+         "total_overages": len(all_overages),
+         "he_overages": len([o for o in all_overages if o.field == "he"]),
+         "en_overages": len([o for o in all_overages if o.field == "en"]),
+         "max_tokens_checked": max_tokens,
+     }
+
+     if token_counts_he:
+         stats["he_token_stats"] = {
+             "min": min(token_counts_he),
+             "max": max(token_counts_he),
+             "avg": sum(token_counts_he) / len(token_counts_he),
+             "total_entries": len(token_counts_he),
+         }
+
+     if token_counts_en:
+         stats["en_token_stats"] = {
+             "min": min(token_counts_en),
+             "max": max(token_counts_en),
+             "avg": sum(token_counts_en) / len(token_counts_en),
+             "total_entries": len(token_counts_en),
+         }
+
+     return all_overages, stats
+
+
+ def print_report(overages: list[TokenOverage], stats: dict) -> None:
+     """Print a formatted report of token limit violations."""
+     print("\n" + "=" * 70)
+     print("TOKEN LIMIT CHECK REPORT")
+     print("=" * 70)
+
+     print("\nDataset Summary:")
+     print(f"  Total entries checked: {stats['total_entries']:,}")
+     print(f"  Token limit: {stats['max_tokens_checked']:,}")
+
+     if "he_token_stats" in stats:
+         he_stats = stats["he_token_stats"]
+         print("\nHebrew/Aramaic Token Statistics:")
+         print(f"  Min tokens: {he_stats['min']:,}")
+         print(f"  Max tokens: {he_stats['max']:,}")
+         print(f"  Avg tokens: {he_stats['avg']:.1f}")
+
+     if "en_token_stats" in stats:
+         en_stats = stats["en_token_stats"]
+         print("\nEnglish Token Statistics:")
+         print(f"  Min tokens: {en_stats['min']:,}")
+         print(f"  Max tokens: {en_stats['max']:,}")
+         print(f"  Avg tokens: {en_stats['avg']:.1f}")
+
+     print("\nOverage Summary:")
+     print(f"  Entries exceeding limit: {stats['entries_with_overages']:,}")
+     print(f"  Total field overages: {stats['total_overages']:,}")
+     print(f"    - Hebrew/Aramaic fields: {stats['he_overages']:,}")
+     print(f"    - English fields: {stats['en_overages']:,}")
+
+     if overages:
+         print("\n" + "-" * 70)
+         print("FLAGGED ENTRIES (exceeding token limit):")
+         print("-" * 70)
+
+         # Group by category
+         by_category = {}
+         for overage in overages:
+             if overage.category not in by_category:
+                 by_category[overage.category] = []
+             by_category[overage.category].append(overage)
+
+         for category, category_overages in sorted(by_category.items()):
+             print(f"\n[{category}] - {len(category_overages)} overage(s)")
+             for overage in category_overages:
+                 print(f"\n  Reference: {overage.ref}")
+                 print(f"  Field: {overage.field}")
+                 print(f"  Token count: {overage.token_count:,} (limit: {stats['max_tokens_checked']:,})")
+                 print(f"  Character count: {overage.char_count:,}")
+                 print(f"  Preview: {overage.text_preview}")
+     else:
+         print("\n✓ No entries exceed the token limit!")
+
+     print("\n" + "=" * 70)
+
+
+ def save_report(
+     overages: list[TokenOverage],
+     stats: dict,
+     output_path: str
+ ) -> None:
+     """Save the report to a JSON file."""
+     report = {
+         "stats": stats,
+         "overages": [
+             {
+                 "ref": o.ref,
+                 "category": o.category,
+                 "field": o.field,
+                 "token_count": o.token_count,
+                 "char_count": o.char_count,
+                 "text_preview": o.text_preview,
+             }
+             for o in overages
+         ]
+     }
+
+     with open(output_path, "w", encoding="utf-8") as f:
+         json.dump(report, f, ensure_ascii=False, indent=2)
+
+     print(f"\nReport saved to: {output_path}")
+
+
+ def main():
+     parser = argparse.ArgumentParser(
251
+ description="Check benchmark data for entries exceeding OpenAI embedding token limits."
252
+ )
253
+ parser.add_argument(
254
+ "--data",
255
+ "-d",
256
+ type=str,
257
+ default="benchmark_data/benchmark.json",
258
+ help="Path to the benchmark JSON file (default: benchmark_data/benchmark.json)"
259
+ )
260
+ parser.add_argument(
261
+ "--max-tokens",
262
+ "-m",
263
+ type=int,
264
+ default=MAX_TOKENS,
265
+ help=f"Maximum allowed tokens (default: {MAX_TOKENS})"
266
+ )
267
+ parser.add_argument(
268
+ "--output",
269
+ "-o",
270
+ type=str,
271
+ help="Path to save JSON report (optional)"
272
+ )
273
+ parser.add_argument(
274
+ "--verbose",
275
+ "-v",
276
+ action="store_true",
277
+ help="Print progress information"
278
+ )
279
+
280
+ args = parser.parse_args()
281
+
282
+ # Check if data file exists
283
+ if not Path(args.data).exists():
284
+ print(f"Error: Data file not found: {args.data}")
285
+ return 1
286
+
287
+ # Run the check
288
+ overages, stats = check_benchmark_data(
289
+ args.data,
290
+ max_tokens=args.max_tokens,
291
+ verbose=args.verbose
292
+ )
293
+
294
+ # Print report
295
+ print_report(overages, stats)
296
+
297
+ # Save report if requested
298
+ if args.output:
299
+ save_report(overages, stats, args.output)
300
+
301
+ # Return exit code based on whether overages were found
302
+ return 1 if overages else 0
303
+
304
+
305
+ if __name__ == "__main__":
306
+ exit(main())
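The flow of the check script above (count tokens per field, flag anything over the limit, aggregate) can be sketched without the tiktoken dependency. This is an illustrative stand-in only: `count_tokens_stub` uses whitespace splitting instead of the real cl100k_base tokenizer, and `find_overages` is a hypothetical helper, not the script's actual API.

```python
def count_tokens_stub(text):
    # Stand-in tokenizer: the real script counts cl100k_base tokens via tiktoken
    return len(text.split())

def find_overages(entries, max_tokens):
    # Flag any he/en field whose (stub) token count exceeds the limit
    flagged = []
    for entry in entries:
        for field in ("he", "en"):
            n = count_tokens_stub(entry[field])
            if n > max_tokens:
                flagged.append({"ref": entry["ref"], "field": field, "token_count": n})
    return flagged

entries = [
    {"ref": "Berakhot 2a:1", "he": "short text", "en": "short text"},
    {"ref": "Berakhot 2a:2", "he": "one two three four five", "en": "ok"},
]
overages = find_overages(entries, max_tokens=4)
print(overages)  # flags only the 5-"token" Hebrew field of Berakhot 2a:2
```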
data_loader.py ADDED
@@ -0,0 +1,828 @@
+ """
+ Data loader for Rabbinic Hebrew/Aramaic benchmark texts from Sefaria API.
+
+ Fetches parallel Hebrew/Aramaic + English text pairs across diverse categories.
+ """
+
+ import json
+ import os
+ import re
+ import time
+ from pathlib import Path
+ from typing import Optional
+
+ import requests
+ import tiktoken
+
+ # Token limit for OpenAI embedding models (text-embedding-ada-002, text-embedding-3-*)
+ # Using cl100k_base encoding
+ MAX_EMBEDDING_TOKENS = 8192
+ _tokenizer = None
+
+
+ def get_tokenizer() -> tiktoken.Encoding:
+     """Get or create the tiktoken encoder (cached for performance)."""
+     global _tokenizer
+     if _tokenizer is None:
+         _tokenizer = tiktoken.get_encoding("cl100k_base")
+     return _tokenizer
+
+
+ def count_tokens(text: str) -> int:
+     """Count the number of tokens in a text string using OpenAI's tokenizer."""
+     return len(get_tokenizer().encode(text))
+
+ # Sefaria host - configurable via environment variable
+ # Default is the public Sefaria API
+ DEFAULT_SEFARIA_HOST = "https://www.sefaria.org"
+ SEFARIA_HOST = os.environ.get("SEFARIA_HOST", DEFAULT_SEFARIA_HOST)
+
+
+ def set_sefaria_host(host: str) -> None:
+     """Set the Sefaria host URL (e.g., 'http://localhost:8000')."""
+     global SEFARIA_HOST
+     # Remove trailing slash if present
+     SEFARIA_HOST = host.rstrip("/")
+
+
+ def get_sefaria_host() -> str:
+     """Get the current Sefaria host URL."""
+     return SEFARIA_HOST
+
+ # Text categories with confirmed English translations
+ BENCHMARK_TEXTS = {
+     "talmud_bavli": {
+         "category": "Talmud",
+         "language": "Aramaic/Hebrew",
+         "texts": [
+             "Berakhot",
+             "Pesachim",
+             "Yoma",
+             "Megillah",
+             "Chagigah",
+             "Ketubot",
+             "Gittin",
+             "Bava Metzia",
+             "Sanhedrin",
+             "Avodah Zarah",
+             "Chullin",
+             "Niddah",
+         ],
+     },
+     "talmud_yerushalmi": {
+         "category": "Jerusalem Talmud",
+         "language": "Aramaic/Hebrew",
+         "texts": [
+             "Jerusalem Talmud Berakhot",
+             "Jerusalem Talmud Kilayim",
+             "Jerusalem Talmud Terumot",
+             "Jerusalem Talmud Shabbat",
+             "Jerusalem Talmud Shekalim",
+             "Jerusalem Talmud Sukkah",
+             "Jerusalem Talmud Sotah",
+             "Jerusalem Talmud Nedarim",
+             "Jerusalem Talmud Kiddushin",
+             "Jerusalem Talmud Bava Kamma",
+             "Jerusalem Talmud Sanhedrin",
+             "Jerusalem Talmud Avodah Zarah",
+             "Jerusalem Talmud Niddah",
+         ],
+     },
+     "mishnah": {
+         "category": "Mishnah",
+         "language": "Rabbinic Hebrew",
+         "texts": [
+             "Mishnah Berakhot",
+             "Mishnah Peah",
+             "Mishnah Kilayim",
+             "Mishnah Shabbat",
+             "Mishnah Pesachim",
+             "Mishnah Sukkah",
+             "Mishnah Taanit",
+             "Mishnah Chagigah",
+             "Mishnah Yevamot",
+             "Mishnah Sotah",
+             "Mishnah Kiddushin",
+             "Mishnah Bava Kamma",
+             "Mishnah Sanhedrin",
+             "Mishnah Eduyot",
+             "Mishnah Avot",
+             "Mishnah Zevachim",
+             "Mishnah Chullin",
+             "Mishnah Tamid",
+             "Mishnah Kelim",
+             "Mishnah Parah",
+             "Mishnah Niddah",
+         ],
+     },
+     "midrash_rabbah": {
+         "category": "Midrash Rabbah",
+         "language": "Hebrew/Aramaic",
+         "texts": [
+             "Bereishit Rabbah",
+             "Shemot Rabbah",
+             "Vayikra Rabbah",
+             "Bamidbar Rabbah",
+             "Devarim Rabbah",
+             "Shir HaShirim Rabbah",
+             "Ruth Rabbah",
+             "Eichah Rabbah",
+             "Kohelet Rabbah",
+             "Esther Rabbah",
+         ],
+     },
+     "tanakh_commentary": {
+         "category": "Tanakh Commentary",
+         "language": "Hebrew",
+         "texts": [
+             "Rashi on Genesis",
+             "Rashi on Exodus",
+             "Rashi on Leviticus",
+             "Rashi on Numbers",
+             "Rashi on Deuteronomy",
+             "Ramban on Genesis",
+             "Ramban on Exodus",
+             "Ramban on Leviticus",
+             "Ramban on Numbers",
+             "Ramban on Deuteronomy",
+             "Radak on Genesis",
+             "Akeidat Yitzchak",
+             "Rabbeinu Behaye, Bereshit",
+             "Rabbeinu Behaye, Shemot",
+             "Rabbeinu Behaye, Vayikra",
+             "Rabbeinu Behaye, Bamidbar",
+             "Rabbeinu Behaye, Devarim",
+         ],
+     },
+     "hasidic_kabbalistic": {
+         "category": "Hasidic/Kabbalistic",
+         "language": "Hebrew",
+         "texts": [
+             "Likutei Moharan",
+             "Tomer Devorah",
+             "Or Neerav, PART I",
+             "Or Neerav, PART II",
+             "Or Neerav, PART III",
+             "Shekel HaKodesh, On Abstinence",
+             "Shekel HaKodesh, On Wisdom",
+             "Kalach Pitchei Chokhmah",
+         ],
+     },
+     "halacha": {
+         "category": "Halacha",
+         "language": "Hebrew",
+         "texts": [
+             "Sefer HaChinukh",
+             "Shev Shmateta, Introduction",
+             "Mishneh Torah, Human Dispositions",
+             "Sefer Yesodei HaTorah",
+         ],
+     },
+     "philosophy": {
+         "category": "Philosophy",
+         "language": "Hebrew",
+         "texts": [
+             "Sefer HaIkkarim, Maamar 1",
+             "Sefer HaIkkarim, Maamar 2",
+             "Sefer HaIkkarim, Maamar 3",
+             "Guide for the Perplexed, Part 1",
+             "Guide for the Perplexed, Part 2",
+             "Guide for the Perplexed, Part 3",
+         ],
+     },
+     "targum": {
+         "category": "Targum",
+         "language": "Aramaic",
+         "texts": [
+             "Aramaic Targum to Song of Songs",
+         ],
+     },
+     "mussar": {
+         "category": "Mussar/Ethics",
+         "language": "Hebrew",
+         "texts": [
+             "Iggeret HaRamban",
+             "Shulchan Shel Arba",
+             "Chafetz Chaim",
+             "Yesod HaYirah, On Endurance",
+             "Yesod HaYirah, On Humility",
+             "Kav HaYashar",
+         ],
+     },
+ }
+
+
+ def strip_html(text: str) -> str:
+     """
+     Remove HTML tags from text.
+
+     Some tags are dropped completely with their content:
+     - <sup class="footnote-marker">...</sup>
+     - <i class="footnote"...>...</i>
+
+     Other tags are stripped but their inner content is preserved.
+     """
+     # First, remove footnote markers (simple, no nesting issues)
+     clean = re.sub(r'<sup[^>]*class="footnote-marker"[^>]*>.*?</sup>', '', text, flags=re.DOTALL)
+
+     # Remove footnotes with nested <i> tags - need to handle nesting
+     # Strategy: find footnote start, then count <i> and </i> to find matching close
+     clean = _remove_footnote_tags(clean)
+
+     # Then strip remaining HTML tags (keeping their content)
+     clean = re.sub(r"<[^>]+>", "", clean)
+
+     # Clean up extra whitespace
+     clean = re.sub(r"\s+", " ", clean).strip()
+     return clean
+
+
+ def _remove_footnote_tags(text: str) -> str:
+     """Remove <i class="footnote"...>...</i> tags, handling nested <i> tags."""
+     result = []
+     i = 0
+
+     while i < len(text):
+         # Look for footnote opening tag
+         match = re.match(r'<i[^>]*class="footnote"[^>]*>', text[i:], flags=re.IGNORECASE)
+         if match:
+             # Found a footnote, now find the matching </i>
+             start = i + match.end()
+             depth = 1
+             j = start
+
+             while j < len(text) and depth > 0:
+                 if text[j:j+3].lower() == '<i ' or text[j:j+3].lower() == '<i>':
+                     depth += 1
+                     j += 1
+                 elif text[j:j+4].lower() == '</i>':
+                     depth -= 1
+                     if depth == 0:
+                         # Skip past the closing </i>
+                         j += 4
+                         break
+                     j += 1
+                 else:
+                     j += 1
+
+             # Skip the entire footnote (from i to j)
+             i = j
+         else:
+             result.append(text[i])
+             i += 1
+
+     return ''.join(result)
+
+
+ def extract_bold_only(text: str) -> str:
+     """
+     Extract only content within <b>...</b> tags, for Talmud Bavli.
+
+     The Steinsaltz English has bold for actual translation and non-bold for
+     elucidation. We only want the translation.
+
+     Example:
+         "<b>The Rabbis say:</b> The time for... is <b>until midnight.</b>"
+         -> "The Rabbis say: until midnight."
+     """
+     # Find all content within <b>...</b> tags
+     bold_parts = re.findall(r'<b>(.*?)</b>', text, flags=re.DOTALL)
+
+     if not bold_parts:
+         # No bold tags found, fall back to regular strip
+         return strip_html(text)
+
+     # Strip any nested HTML from each bold part and join with spaces
+     cleaned_parts = [strip_html(part) for part in bold_parts]
+
+     # Join parts, ensuring proper spacing
+     result = ' '.join(cleaned_parts)
+
+     # Clean up extra whitespace
+     result = re.sub(r"\s+", " ", result).strip()
+     return result
+
+
+ def get_text_from_sefaria(ref: str, retries: int = 3) -> Optional[dict]:
+     """
+     Fetch a text from Sefaria API.
+
+     Args:
+         ref: Sefaria reference string (e.g., "Berakhot.2a")
+         retries: Number of retry attempts
+
+     Returns:
+         Dict with 'he' (Hebrew/Aramaic) and 'en' (English) texts, or None if failed/error
+     """
+     url = f"{SEFARIA_HOST}/api/texts/{ref}"
+     params = {"context": 0}
+
+     for attempt in range(retries):
+         try:
+             response = requests.get(url, params=params, timeout=30)
+             if response.status_code == 200:
+                 data = response.json()
+                 # Check if response contains an error
+                 if "error" in data:
+                     return None
+                 return data
+             elif response.status_code == 429:
+                 # Rate limited, wait and retry
+                 time.sleep(2 ** attempt)
+             else:
+                 return None
+         except requests.RequestException:
+             if attempt < retries - 1:
+                 time.sleep(1)
+                 continue
+     return None
+
+
+ def get_index_from_sefaria(title: str) -> Optional[dict]:
+     """
+     Get index/structure information for a text.
+
+     Args:
+         title: The title of the text
+
+     Returns:
+         Index data or None if failed or text not found
+     """
+     url = f"{SEFARIA_HOST}/api/index/{title}"
+     try:
+         response = requests.get(url, timeout=30)
+         if response.status_code == 200:
+             data = response.json()
+             # Check if response contains an error
+             if "error" in data:
+                 return None
+             return data
+     except requests.RequestException:
+         pass
+     return None
+
+
+ def extract_parallel_segments(data: dict, ref: str, category: str = "") -> list[dict]:
+     """
+     Extract parallel Hebrew/English segments from API response.
+
+     Args:
+         data: API response data
+         ref: The reference string
+         category: Category name (used for special handling, e.g., "Talmud")
+
+     Returns:
+         List of dicts with 'ref', 'he', 'en' keys
+     """
+     segments = []
+
+     he_text = data.get("he", [])
+     en_text = data.get("text", [])
+
+     # Handle nested arrays (common in Talmud)
+     if he_text and isinstance(he_text, list):
+         # Flatten if nested
+         if he_text and isinstance(he_text[0], list):
+             he_flat = []
+             en_flat = []
+             for i, (he_seg, en_seg) in enumerate(zip(he_text, en_text)):
+                 if isinstance(he_seg, list):
+                     he_flat.extend(he_seg)
+                     en_flat.extend(en_seg if isinstance(en_seg, list) else [en_seg])
+                 else:
+                     he_flat.append(he_seg)
+                     en_flat.append(en_seg)
+             he_text = he_flat
+             en_text = en_flat
+
+     # Handle single string responses
+     if isinstance(he_text, str):
+         he_text = [he_text]
+     if isinstance(en_text, str):
+         en_text = [en_text]
+
+     # For Talmud Bavli, extract only bold text (actual translation, not elucidation)
+     is_bavli = category == "Talmud"
+
+     # Pair up segments
+     for i, (he, en) in enumerate(zip(he_text, en_text)):
+         if he and en:
+             he_clean = strip_html(str(he)) if he else ""
+             # Use bold-only extraction for Bavli English
+             if is_bavli:
+                 en_clean = extract_bold_only(str(en)) if en else ""
+             else:
+                 en_clean = strip_html(str(en)) if en else ""
+
+             # Skip empty or very short segments
+             if len(he_clean) > 10 and len(en_clean) > 10:
+                 # Check token limits for OpenAI embedding models
+                 he_tokens = count_tokens(he_clean)
+                 en_tokens = count_tokens(en_clean)
+
+                 if he_tokens > MAX_EMBEDDING_TOKENS:
+                     print(f"  Skipping {ref}:{i+1} - Hebrew text exceeds token limit ({he_tokens} > {MAX_EMBEDDING_TOKENS})")
+                     continue
+                 if en_tokens > MAX_EMBEDDING_TOKENS:
+                     print(f"  Skipping {ref}:{i+1} - English text exceeds token limit ({en_tokens} > {MAX_EMBEDDING_TOKENS})")
+                     continue
+
+                 segments.append({
+                     "ref": f"{ref}:{i+1}" if ":" not in ref else ref,
+                     "he": he_clean,
+                     "en": en_clean,
+                 })
+
+     return segments
+
+
+ def fetch_text_pairs(
+     text_title: str,
+     category: str,
+     max_segments: int = 500,
+     delay: float = 0.5
+ ) -> list[dict]:
+     """
+     Fetch parallel text pairs for a given text.
+
+     Args:
+         text_title: Title of the text to fetch
+         category: Category name for metadata
+         max_segments: Maximum segments to fetch per text
+         delay: Delay between API calls (rate limiting)
+
+     Returns:
+         List of segment dicts with ref, he, en, category
+     """
+     pairs = []
+
+     # Get text index to understand structure
+     index = get_index_from_sefaria(text_title)
+     if not index:
+         print(f"  Could not get index for {text_title}")
+         return pairs
+
+     # Determine refs to fetch based on text structure
+     schema = index.get("schema", {})
+
+     # For simple texts, just fetch the whole thing
+     if schema.get("nodeType") == "JaggedArrayNode":
+         depth = schema.get("depth", 2)
+         address_types = schema.get("addressTypes", [])
+
+         # Check if this uses Talmud daf notation (2a, 2b, etc.)
+         uses_talmud_daf = address_types and address_types[0] == "Talmud"
+
+         if uses_talmud_daf:
+             # Talmud-style structure with daf notation (e.g., Berakhot.2a)
+             # Start from daf 3 for Jerusalem Talmud to avoid overlap with Bavli
+             start_daf = 3 if category == "Jerusalem Talmud" else 2
+             # Fetch daf by daf
+             done = False
+             for daf_num in range(start_daf, 200):
+                 if len(pairs) >= max_segments or done:
+                     break
+
+                 for side in ["a", "b"]:
+                     if len(pairs) >= max_segments:
+                         break
+
+                     ref = f"{text_title}.{daf_num}{side}"
+                     data = get_text_from_sefaria(ref)
+
+                     # None means API error (daf doesn't exist)
+                     if data is None:
+                         if side == "a":
+                             done = True  # Daf doesn't exist, we're done with tractate
+                         break
+
+                     if not data.get("he"):
+                         continue  # Empty side, try next
+
+                     segments = extract_parallel_segments(data, ref, category)
+                     for seg in segments:
+                         seg["category"] = category
+                     pairs.extend(segments)
+
+                     time.sleep(delay)
+
+         elif depth == 1:
+             # Single-level structure (e.g., Iggeret HaRamban - just paragraphs)
+             # Fetch the whole text at once
+             data = get_text_from_sefaria(text_title)
+             if data and data.get("he"):
+                 segments = extract_parallel_segments(data, text_title, category)
+                 for seg in segments:
+                     seg["category"] = category
+                 pairs.extend(segments)
+
+         elif depth == 2:
+             # Two-level structure (e.g., Mishnah chapter:verse)
+             # Start from chapter 2 for Mishnah to avoid overlap with Talmud
+             start_chapter = 2 if category == "Mishnah" else 1
+             consecutive_empty = 0
+             # Fetch chapter by chapter
+             for chapter in range(start_chapter, 200):  # Reasonable upper bound
+                 if len(pairs) >= max_segments:
+                     break
+
+                 ref = f"{text_title}.{chapter}"
+                 data = get_text_from_sefaria(ref)
+
+                 # None means API error (ref doesn't exist)
+                 if data is None:
+                     break
+
+                 # Empty array means chapter exists but has no content
+                 if not data.get("he"):
+                     consecutive_empty += 1
+                     if consecutive_empty >= 5:
+                         break  # Probably past end of book
+                     time.sleep(delay)
+                     continue
+
+                 consecutive_empty = 0
+                 segments = extract_parallel_segments(data, ref, category)
+                 for seg in segments:
+                     seg["category"] = category
+                 pairs.extend(segments)
+
+                 time.sleep(delay)
+
+         elif depth >= 3:
+             # Three+ level structure (e.g., commentary chapter:verse:comment)
+             # Fetch chapter.verse by chapter.verse
+             # For Jerusalem Talmud, start from 1.3 to avoid overlap with Bavli
+             start_verse = 3 if category == "Jerusalem Talmud" else 1
+             consecutive_empty_chapters = 0
+             for chapter in range(1, 200):
+                 if len(pairs) >= max_segments:
+                     break
+
+                 chapter_had_content = False
+                 # Use start_verse only for first chapter
+                 first_verse = start_verse if chapter == 1 else 1
+                 for verse in range(first_verse, 100):
+                     if len(pairs) >= max_segments:
+                         break
+
+                     ref = f"{text_title}.{chapter}.{verse}"
+                     data = get_text_from_sefaria(ref)
+
+                     # None means API error (ref doesn't exist)
+                     if data is None:
+                         break  # No more verses in this chapter
+
+                     # Empty array means verse exists but has no content
+                     if not data.get("he"):
+                         continue
+
+                     chapter_had_content = True
+                     segments = extract_parallel_segments(data, ref, category)
+                     for seg in segments:
+                         seg["category"] = category
+                     pairs.extend(segments)
+
+                     time.sleep(delay)
+
+                 if not chapter_had_content:
+                     consecutive_empty_chapters += 1
+                     if consecutive_empty_chapters >= 5:
+                         break  # Probably past end of book
+                 else:
+                     consecutive_empty_chapters = 0
+
+     else:
+         # Complex structure (SchemaNode) - try different ref patterns
+         # First try simple numeric refs (works for Sefer HaChinukh style)
+         consecutive_empty = 0
+         for section in range(1, 1000):
+             if len(pairs) >= max_segments:
+                 break
+
+             ref = f"{text_title}.{section}"
+             data = get_text_from_sefaria(ref)
+
+             if data is None:
+                 break
+
+             if not data.get("he"):
+                 consecutive_empty += 1
+                 if consecutive_empty >= 5:
+                     break
+                 time.sleep(delay)
+                 continue
+
+             consecutive_empty = 0
+             segments = extract_parallel_segments(data, ref, category)
+             for seg in segments:
+                 seg["category"] = category
+             pairs.extend(segments)
+
+             time.sleep(delay)
+
+         # If we haven't reached max_segments, try chapter.verse style refs (commentary pattern)
+         if len(pairs) < max_segments:
+             consecutive_empty = 0
+             for chapter in range(1, 100):
+                 if len(pairs) >= max_segments:
+                     break
+
+                 chapter_had_content = False
+                 for verse in range(1, 50):
+                     if len(pairs) >= max_segments:
+                         break
+
+                     ref = f"{text_title}.{chapter}.{verse}"
+                     data = get_text_from_sefaria(ref)
+
+                     if data is None:
+                         break  # This verse doesn't exist, try next chapter
+
+                     if data.get("he"):
+                         chapter_had_content = True
+                         consecutive_empty = 0
+                         segments = extract_parallel_segments(data, ref, category)
+                         for seg in segments:
+                             seg["category"] = category
+                         pairs.extend(segments)
+
+                     time.sleep(delay)
+
+                 if not chapter_had_content:
+                     consecutive_empty += 1
+                     if consecutive_empty >= 5:
+                         break
+
+     return pairs[:max_segments]
+
+
+ def build_benchmark_dataset(
+     output_path: str = "benchmark_data/benchmark.json",
+     segments_per_text: int = 200,
+     total_target: int = 10000,
+ ) -> list[dict]:
+     """
+     Build the full benchmark dataset from all configured texts.
+
+     Args:
+         output_path: Path to save the benchmark JSON
+         segments_per_text: Target segments per text
+         total_target: Overall target segment count
+
+     Returns:
+         List of all benchmark pairs
+     """
+     all_pairs = []
+
+     for category_key, category_info in BENCHMARK_TEXTS.items():
+         category_name = category_info["category"]
+         texts = category_info["texts"]
+
+         print(f"\n{'='*60}")
+         print(f"Processing category: {category_name}")
+         print(f"{'='*60}")
+
+         for text_title in texts:
+             if len(all_pairs) >= total_target:
+                 break
+
+             print(f"\nFetching: {text_title}")
+
+             pairs = fetch_text_pairs(
+                 text_title,
+                 category_name,
+                 max_segments=segments_per_text,
+             )
+
+             print(f"  Got {len(pairs)} pairs")
+             all_pairs.extend(pairs)
+
+         if len(all_pairs) >= total_target:
+             break
+
+     # Save to file
+     output_file = Path(output_path)
+     output_file.parent.mkdir(parents=True, exist_ok=True)
+
+     with open(output_file, "w", encoding="utf-8") as f:
+         json.dump(all_pairs, f, ensure_ascii=False, indent=2)
+
+     print(f"\n{'='*60}")
+     print(f"Total pairs collected: {len(all_pairs)}")
+     print(f"Saved to: {output_path}")
+
+     # Save stats to markdown file
+     stats = get_benchmark_stats(all_pairs)
+     save_stats_markdown(stats, output_path)
+
+     return all_pairs
+
+
+ def load_benchmark_dataset(path: str = "benchmark_data/benchmark.json") -> list[dict]:
+     """
+     Load the pre-cached benchmark dataset.
+
+     Args:
+         path: Path to the benchmark JSON file
+
+     Returns:
+         List of benchmark pairs
+     """
+     with open(path, "r", encoding="utf-8") as f:
+         return json.load(f)
+
+
+ def get_benchmark_stats(pairs: list[dict]) -> dict:
+     """
+     Get statistics about the benchmark dataset.
+
+     Args:
+         pairs: List of benchmark pairs
+
+     Returns:
+         Dict with category counts and other stats
+     """
+     from collections import Counter
+
+     categories = Counter(p["category"] for p in pairs)
+
+     he_lengths = [len(p["he"]) for p in pairs]
+     en_lengths = [len(p["en"]) for p in pairs]
+
+     return {
+         "total_pairs": len(pairs),
+         "categories": dict(categories),
+         "avg_he_length": sum(he_lengths) / len(he_lengths) if he_lengths else 0,
+         "avg_en_length": sum(en_lengths) / len(en_lengths) if en_lengths else 0,
+     }
+
+
+ def save_stats_markdown(stats: dict, data_path: str) -> str:
+     """
+     Save benchmark statistics to a markdown file alongside the data.
+
+     Args:
+         stats: Statistics dict from get_benchmark_stats()
+         data_path: Path to the data file (used to derive stats file path)
+
+     Returns:
+         Path to the saved markdown file
+     """
+     from datetime import datetime
+
+     # Derive markdown path from data path
+     data_file = Path(data_path)
+     stats_path = data_file.with_suffix(".stats.md")
+
+     # Build markdown content
+     lines = [
+         "# Benchmark Dataset Statistics",
+         "",
+         f"**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
+         "",
+         "## Summary",
+         "",
+         f"- **Total pairs:** {stats['total_pairs']:,}",
+         f"- **Average Hebrew length:** {stats['avg_he_length']:.0f} chars",
+         f"- **Average English length:** {stats['avg_en_length']:.0f} chars",
+         "",
+         "## Category Breakdown",
+         "",
+         "| Category | Count |",
+         "|----------|-------|",
+     ]
+
+     # Sort categories by count (descending)
+     sorted_categories = sorted(
+         stats["categories"].items(),
+         key=lambda x: x[1],
+         reverse=True
+     )
+
+     for category, count in sorted_categories:
+         lines.append(f"| {category} | {count:,} |")
+
+     lines.append("")
+
+     # Write to file
+     with open(stats_path, "w", encoding="utf-8") as f:
+         f.write("\n".join(lines))
+
+     print(f"Stats saved to: {stats_path}")
+     return str(stats_path)
+
+
+ if __name__ == "__main__":
+     # Build the benchmark dataset
+     print("Building Rabbinic Hebrew/Aramaic benchmark dataset...")
+     pairs = build_benchmark_dataset()
+
+     # Print stats
+     stats = get_benchmark_stats(pairs)
+     print(f"\nDataset Statistics:")
+     print(f"  Total pairs: {stats['total_pairs']}")
+     print(f"  Categories: {stats['categories']}")
+     print(f"  Avg Hebrew length: {stats['avg_he_length']:.0f} chars")
+     print(f"  Avg English length: {stats['avg_en_length']:.0f} chars")
+
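The Steinsaltz bold-only extraction in `extract_bold_only` above can be illustrated with a self-contained regex sketch. This is a simplified re-implementation for demonstration (it does not import the module, and `bold_only` is an illustrative name):

```python
import re

def bold_only(html: str) -> str:
    # Keep only the content of <b>...</b> spans (the literal translation),
    # strip any tags nested inside them, and collapse whitespace.
    parts = re.findall(r"<b>(.*?)</b>", html, flags=re.DOTALL)
    text = " ".join(re.sub(r"<[^>]+>", "", part) for part in parts)
    return re.sub(r"\s+", " ", text).strip()

sample = '<b>The Rabbis say:</b> The time for reciting is <b>until midnight.</b>'
print(bold_only(sample))  # -> The Rabbis say: until midnight.
```

The non-bold span ("The time for reciting is") is elucidation and is dropped, matching the docstring example in the diff.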
evaluation.py ADDED
@@ -0,0 +1,384 @@
1
+ """
2
+ Cross-lingual retrieval evaluation for Rabbinic embedding benchmark.
3
+
4
+ Computes retrieval metrics to measure how well embedding models align
5
+ Hebrew/Aramaic source texts with their English translations.
6
+ """
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
+ import numpy as np
11
+
12
+
13
+ @dataclass
14
+ class EvaluationResults:
15
+ """Container for evaluation results."""
16
+
17
+ model_id: str
18
+ model_name: str
19
+
20
+ # Core retrieval metrics
21
+ recall_at_1: float
22
+ recall_at_5: float
23
+ recall_at_10: float
24
+ mrr: float # Mean Reciprocal Rank
25
+
26
+ # Additional metrics
27
+ bitext_accuracy: float # True pair vs random pair classification
28
+ avg_true_pair_similarity: float
29
+ avg_random_pair_similarity: float
30
+
31
+ # Metadata
32
+ num_pairs: int
33
+ categories: dict[str, int]
34
+
35
+ def to_dict(self) -> dict:
36
+ """Convert to dictionary for JSON serialization."""
37
+ return {
38
+ "model_id": self.model_id,
39
+ "model_name": self.model_name,
40
+ "recall_at_1": self.recall_at_1,
41
+ "recall_at_5": self.recall_at_5,
42
+ "recall_at_10": self.recall_at_10,
43
+ "mrr": self.mrr,
44
+ "bitext_accuracy": self.bitext_accuracy,
45
+ "avg_true_pair_similarity": self.avg_true_pair_similarity,
46
+ "avg_random_pair_similarity": self.avg_random_pair_similarity,
47
+ "num_pairs": self.num_pairs,
48
+ "categories": self.categories,
49
+ }
50
+
51
+ @classmethod
52
+ def from_dict(cls, data: dict) -> "EvaluationResults":
53
+ """Create from dictionary."""
54
+ return cls(**data)
55
+
+
+ def compute_similarity_matrix(
+     query_embeddings: np.ndarray,
+     passage_embeddings: np.ndarray,
+ ) -> np.ndarray:
+     """
+     Compute the cosine similarity matrix between queries and passages.
+
+     Assumes the embeddings are already L2-normalized, so the dot product
+     equals cosine similarity.
+
+     Args:
+         query_embeddings: (N, D) array of query embeddings
+         passage_embeddings: (M, D) array of passage embeddings
+
+     Returns:
+         (N, M) similarity matrix
+     """
+     return np.dot(query_embeddings, passage_embeddings.T)
+
+
+ def compute_retrieval_metrics(
+     similarity_matrix: np.ndarray,
+     k_values: list[int] = [1, 5, 10],
+ ) -> dict[str, float]:
+     """
+     Compute retrieval metrics from a similarity matrix.
+
+     Assumes the correct match for query i is passage i (the diagonal).
+
+     Args:
+         similarity_matrix: (N, N) similarity matrix whose diagonal holds the true matches
+         k_values: List of k values for Recall@k
+
+     Returns:
+         Dict with recall@k and mrr values
+     """
+     n = similarity_matrix.shape[0]
+
+     # Rank passages for each query; negate to sort descending (highest similarity first)
+     rankings = np.argsort(-similarity_matrix, axis=1)
+
+     # Find the rank of the true match (diagonal) for each query
+     true_ranks = np.zeros(n, dtype=int)
+     for i in range(n):
+         # Position of index i in the ranking for query i
+         true_ranks[i] = np.where(rankings[i] == i)[0][0]
+
+     results = {}
+
+     # Recall@k: fraction of queries whose true match is in the top k
+     for k in k_values:
+         recall = np.mean(true_ranks < k)
+         results[f"recall_at_{k}"] = float(recall)
+
+     # MRR: Mean Reciprocal Rank (+1 because ranks are 0-indexed)
+     reciprocal_ranks = 1.0 / (true_ranks + 1)
+     results["mrr"] = float(np.mean(reciprocal_ranks))
+
+     return results
+
+
+ def compute_bitext_accuracy(
+     similarity_matrix: np.ndarray,
+     num_negatives: int = 10,
+ ) -> tuple[float, float, float]:
+     """
+     Compute bitext mining accuracy.
+
+     For each true pair, sample random negative pairs and check whether the
+     model ranks the true pair higher.
+
+     Args:
+         similarity_matrix: (N, N) similarity matrix
+         num_negatives: Number of negative samples per true pair
+
+     Returns:
+         Tuple of (accuracy, avg_true_sim, avg_random_sim)
+     """
+     n = similarity_matrix.shape[0]
+
+     # True-pair similarities (the diagonal)
+     true_similarities = np.diag(similarity_matrix)
+
+     # Sample random negative pairs
+     correct = 0
+     total = 0
+     random_sims = []
+
+     rng = np.random.default_rng(42)
+
+     for i in range(n):
+         true_sim = true_similarities[i]
+
+         # Sample random passage indices (excluding the true match)
+         neg_indices = rng.choice(
+             [j for j in range(n) if j != i],
+             size=min(num_negatives, n - 1),
+             replace=False,
+         )
+
+         for j in neg_indices:
+             neg_sim = similarity_matrix[i, j]
+             random_sims.append(neg_sim)
+
+             if true_sim > neg_sim:
+                 correct += 1
+             total += 1
+
+     accuracy = correct / total if total > 0 else 0.0
+     avg_true = float(np.mean(true_similarities))
+     avg_random = float(np.mean(random_sims)) if random_sims else 0.0
+
+     return accuracy, avg_true, avg_random
+
+
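The bitext check reduces to comparing each diagonal entry against off-diagonal entries of the same row. A self-contained sketch of that comparison on a toy matrix (the matrix values here are invented for illustration, and the check is exhaustive rather than sampled):

```python
import numpy as np

# Toy similarity matrix: diagonal (true pairs) is clearly strongest
rng = np.random.default_rng(0)
n = 5
sim = rng.uniform(0.0, 0.5, size=(n, n))
np.fill_diagonal(sim, 0.9)

# Compare each true pair against every negative in its row
correct = sum(
    sim[i, i] > sim[i, j]
    for i in range(n) for j in range(n) if j != i
)
accuracy = correct / (n * (n - 1))
print(accuracy)  # 1.0 - every true pair beats every negative
```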
+ def evaluate_model(
+     model,
+     benchmark_pairs: list[dict],
+     batch_size: int = 32,
+     max_pairs: Optional[int] = None,
+     progress_callback=None,
+ ) -> EvaluationResults:
+     """
+     Run the full evaluation of a model on the benchmark.
+
+     Args:
+         model: EmbeddingModel instance
+         benchmark_pairs: List of benchmark pairs with 'he', 'en', 'category' keys
+         batch_size: Batch size for encoding
+         max_pairs: Maximum number of pairs to evaluate (for faster testing)
+         progress_callback: Optional callback for progress updates
+
+     Returns:
+         EvaluationResults with all metrics
+     """
+     # Use the streaming version and return the final result
+     result = None
+     for item in evaluate_model_streaming(model, benchmark_pairs, batch_size, max_pairs):
+         if isinstance(item, str):
+             if progress_callback:
+                 progress_callback(0.5, item)
+         else:
+             result = item
+     return result
+
+
+ def evaluate_model_streaming(
+     model,
+     benchmark_pairs: list[dict],
+     batch_size: int = 32,
+     max_pairs: Optional[int] = None,
+ ):
+     """
+     Run the evaluation with streaming progress updates.
+
+     Yields progress strings during encoding, then yields the final
+     EvaluationResults.
+
+     Args:
+         model: EmbeddingModel instance
+         benchmark_pairs: List of benchmark pairs with 'he', 'en', 'category' keys
+         batch_size: Batch size for encoding
+         max_pairs: Maximum number of pairs to evaluate (for faster testing)
+
+     Yields:
+         Progress strings, then the final EvaluationResults
+     """
+     from collections import Counter
+
+     # Optionally limit the number of pairs
+     if max_pairs and len(benchmark_pairs) > max_pairs:
+         benchmark_pairs = benchmark_pairs[:max_pairs]
+
+     # Extract texts
+     he_texts = [p["he"] for p in benchmark_pairs]
+     en_texts = [p["en"] for p in benchmark_pairs]
+     categories = Counter(p.get("category", "Unknown") for p in benchmark_pairs)
+     n_total = len(he_texts)
+
+     # Encode Hebrew texts in batches with progress updates
+     yield f"⏳ Encoding Hebrew/Aramaic texts: 0/{n_total:,}"
+     he_embeddings_list = []
+     for i in range(0, len(he_texts), batch_size):
+         batch = he_texts[i:i + batch_size]
+         batch_emb = model.encode(
+             batch,
+             is_query=True,
+             batch_size=batch_size,
+             show_progress=False,
+         )
+         he_embeddings_list.append(batch_emb)
+         done = min(i + batch_size, len(he_texts))
+         yield f"⏳ Encoding Hebrew/Aramaic texts: {done:,}/{n_total:,}"
+
+     he_embeddings = np.vstack(he_embeddings_list)
+
+     # Encode English texts in batches with progress updates
+     yield f"⏳ Encoding English texts: 0/{n_total:,}"
+     en_embeddings_list = []
+     for i in range(0, len(en_texts), batch_size):
+         batch = en_texts[i:i + batch_size]
+         batch_emb = model.encode(
+             batch,
+             is_query=False,
+             batch_size=batch_size,
+             show_progress=False,
+         )
+         en_embeddings_list.append(batch_emb)
+         done = min(i + batch_size, len(en_texts))
+         yield f"⏳ Encoding English texts: {done:,}/{n_total:,}"
+
+     en_embeddings = np.vstack(en_embeddings_list)
+
+     yield "⏳ Computing similarity matrix..."
+     similarity_matrix = compute_similarity_matrix(he_embeddings, en_embeddings)
+
+     yield "⏳ Computing retrieval metrics..."
+     retrieval_metrics = compute_retrieval_metrics(similarity_matrix)
+
+     yield "⏳ Computing bitext accuracy..."
+     bitext_acc, avg_true_sim, avg_random_sim = compute_bitext_accuracy(
+         similarity_matrix
+     )
+
+     # Yield the final results
+     yield EvaluationResults(
+         model_id=model.model_id,
+         model_name=model.name,
+         recall_at_1=retrieval_metrics["recall_at_1"],
+         recall_at_5=retrieval_metrics["recall_at_5"],
+         recall_at_10=retrieval_metrics["recall_at_10"],
+         mrr=retrieval_metrics["mrr"],
+         bitext_accuracy=bitext_acc,
+         avg_true_pair_similarity=avg_true_sim,
+         avg_random_pair_similarity=avg_random_sim,
+         num_pairs=len(benchmark_pairs),
+         categories=dict(categories),
+     )
+
+
+ def evaluate_by_category(
+     model,
+     benchmark_pairs: list[dict],
+     batch_size: int = 32,
+ ) -> dict[str, EvaluationResults]:
+     """
+     Run the evaluation broken down by category.
+
+     Args:
+         model: EmbeddingModel instance
+         benchmark_pairs: List of benchmark pairs
+         batch_size: Batch size for encoding
+
+     Returns:
+         Dict mapping category name to EvaluationResults
+     """
+     from collections import defaultdict
+
+     # Group pairs by category
+     by_category = defaultdict(list)
+     for pair in benchmark_pairs:
+         category = pair.get("category", "Unknown")
+         by_category[category].append(pair)
+
+     results = {}
+     for category, pairs in by_category.items():
+         print(f"Evaluating category: {category} ({len(pairs)} pairs)")
+         results[category] = evaluate_model(model, pairs, batch_size=batch_size)
+
+     return results
+
+
+ def get_rank_distribution(
+     similarity_matrix: np.ndarray,
+     bins: list[int] = [1, 5, 10, 50, 100],
+ ) -> dict[str, int]:
+     """
+     Get the distribution of true-match ranks.
+
+     Args:
+         similarity_matrix: (N, N) similarity matrix
+         bins: Bin boundaries for the histogram
+
+     Returns:
+         Dict mapping bin labels to counts
+     """
+     n = similarity_matrix.shape[0]
+     rankings = np.argsort(-similarity_matrix, axis=1)
+
+     # Find the true rank for each query
+     true_ranks = np.zeros(n, dtype=int)
+     for i in range(n):
+         true_ranks[i] = np.where(rankings[i] == i)[0][0]
+
+     # Build the histogram
+     distribution = {}
+     prev_bin = 0
+     for bin_edge in bins:
+         count = np.sum((true_ranks >= prev_bin) & (true_ranks < bin_edge))
+         label = f"{prev_bin+1}-{bin_edge}" if prev_bin > 0 else f"Top {bin_edge}"
+         distribution[label] = int(count)
+         prev_bin = bin_edge
+
+     # Count the remainder
+     remaining = np.sum(true_ranks >= bins[-1])
+     distribution[f">{bins[-1]}"] = int(remaining)
+
+     return distribution
+
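The binning logic can be checked by hand on a small set of 0-indexed ranks (the rank values below are invented for illustration):

```python
import numpy as np

true_ranks = np.array([0, 0, 3, 7, 42, 120])  # 0-indexed ranks of true matches
bins = [1, 5, 10, 50, 100]

dist, prev = {}, 0
for edge in bins:
    # First bin is labeled "Top k"; later bins use 1-indexed ranges
    label = f"{prev+1}-{edge}" if prev > 0 else f"Top {edge}"
    dist[label] = int(np.sum((true_ranks >= prev) & (true_ranks < edge)))
    prev = edge
dist[f">{bins[-1]}"] = int(np.sum(true_ranks >= bins[-1]))

print(dist)  # {'Top 1': 2, '2-5': 1, '6-10': 1, '11-50': 1, '51-100': 0, '>100': 1}
```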
+
+ if __name__ == "__main__":
+     # Smoke-test the evaluation functions with synthetic data
+     print("Testing evaluation functions...")
+
+     # Near-perfect retrieval: strong diagonal plus a little noise
+     n = 100
+     perfect_matrix = np.eye(n) + np.random.randn(n, n) * 0.1
+
+     metrics = compute_retrieval_metrics(perfect_matrix)
+     print(f"Near-perfect retrieval metrics: {metrics}")
+
+     # Random baseline: similarities between random unit vectors
+     random_matrix = np.random.randn(n, n)
+     random_matrix = random_matrix / np.linalg.norm(random_matrix, axis=1, keepdims=True)
+     random_matrix = np.dot(random_matrix, random_matrix.T)
+
+     metrics = compute_retrieval_metrics(random_matrix)
+     print(f"Random retrieval metrics: {metrics}")
+
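The ranking logic can also be verified on a hand-checkable matrix. A self-contained sketch (the similarity values are invented for illustration):

```python
import numpy as np

# 3x3 toy similarity matrix; the true match for query i is passage i
sim = np.array([
    [0.9, 0.2, 0.1],  # query 0: true match ranked 1st
    [0.8, 0.3, 0.1],  # query 1: true match ranked 2nd
    [0.2, 0.1, 0.7],  # query 2: true match ranked 1st
])

rankings = np.argsort(-sim, axis=1)  # descending similarity
true_ranks = np.array([int(np.where(rankings[i] == i)[0][0]) for i in range(3)])

recall_at_1 = float(np.mean(true_ranks < 1))  # 2 of 3 queries rank their match first
mrr = float(np.mean(1.0 / (true_ranks + 1)))  # (1 + 1/2 + 1) / 3
print(recall_at_1, mrr)
```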
models.py ADDED
@@ -0,0 +1,1063 @@
+ """
+ Model loading and embedding interface for the Rabbinic embedding benchmark.
+
+ Supports:
+ - Curated models from Hugging Face (sentence-transformers)
+ - Any Hugging Face sentence-transformer model
+ - API-based models (OpenAI, Voyage AI, Google Gemini)
+ """
+
+ import os
+ from abc import ABC, abstractmethod
+ from typing import Optional
+ import numpy as np
+
+ # Curated local models known to work well for multilingual tasks
+ CURATED_MODELS = {
+     "intfloat/multilingual-e5-large": {
+         "name": "Multilingual E5 Large",
+         "description": "Strong multilingual model from Microsoft, 560M params",
+         "type": "local",
+         "query_prefix": "query: ",
+         "passage_prefix": "passage: ",
+     },
+     "intfloat/multilingual-e5-base": {
+         "name": "Multilingual E5 Base",
+         "description": "Smaller multilingual E5, 278M params",
+         "type": "local",
+         "query_prefix": "query: ",
+         "passage_prefix": "passage: ",
+     },
+     "sentence-transformers/paraphrase-multilingual-mpnet-base-v2": {
+         "name": "Multilingual MPNet",
+         "description": "Classic multilingual sentence transformer, 278M params",
+         "type": "local",
+         "query_prefix": "",
+         "passage_prefix": "",
+     },
+     "BAAI/bge-m3": {
+         "name": "BGE-M3",
+         "description": "Multi-lingual, multi-functionality, multi-granularity model from BAAI",
+         "type": "local",
+         "query_prefix": "",
+         "passage_prefix": "",
+     },
+     "intfloat/e5-mistral-7b-instruct": {
+         "name": "E5 Mistral 7B",
+         "description": "Large instruction-tuned embedding model, 7B params (requires GPU)",
+         "type": "local",
+         "query_prefix": "Instruct: Retrieve semantically similar text\nQuery: ",
+         "passage_prefix": "",
+     },
+     "Alibaba-NLP/gte-multilingual-base": {
+         "name": "GTE Multilingual Base",
+         "description": "General Text Embeddings multilingual model from Alibaba",
+         "type": "local",
+         "query_prefix": "",
+         "passage_prefix": "",
+     },
+ }
+
+ # API-based models
+ API_MODELS = {
+     "openai/text-embedding-3-large": {
+         "name": "OpenAI text-embedding-3-large",
+         "description": "OpenAI's best embedding model, 3072 dimensions (API key required)",
+         "type": "openai",
+         "model_name": "text-embedding-3-large",
+         "dimensions": 3072,
+     },
+     "openai/text-embedding-3-small": {
+         "name": "OpenAI text-embedding-3-small",
+         "description": "OpenAI's efficient embedding model, 1536 dimensions (API key required)",
+         "type": "openai",
+         "model_name": "text-embedding-3-small",
+         "dimensions": 1536,
+     },
+     "openai/text-embedding-ada-002": {
+         "name": "OpenAI Ada 002",
+         "description": "OpenAI's legacy embedding model, 1536 dimensions (API key required)",
+         "type": "openai",
+         "model_name": "text-embedding-ada-002",
+         "dimensions": 1536,
+     },
+     "voyage/voyage-3.5": {
+         "name": "Voyage AI voyage-3.5",
+         "description": "Voyage AI's latest embedding model (API key required)",
+         "type": "voyage",
+         "model_name": "voyage-3.5",
+         "dimensions": 1024,
+     },
+     "voyage/voyage-3.5-lite": {
+         "name": "Voyage AI voyage-3.5-lite",
+         "description": "Voyage AI's efficient embedding model (API key required)",
+         "type": "voyage",
+         "model_name": "voyage-3.5-lite",
+         "dimensions": 1024,
+     },
+     "voyage/voyage-3": {
+         "name": "Voyage AI voyage-3",
+         "description": "Voyage AI's general purpose embedding model (API key required)",
+         "type": "voyage",
+         "model_name": "voyage-3",
+         "dimensions": 1024,
+     },
+     "voyage/voyage-3-lite": {
+         "name": "Voyage AI voyage-3-lite",
+         "description": "Voyage AI's lightweight embedding model (API key required)",
+         "type": "voyage",
+         "model_name": "voyage-3-lite",
+         "dimensions": 512,
+     },
+     "voyage/voyage-multilingual-2": {
+         "name": "Voyage AI voyage-multilingual-2",
+         "description": "Voyage AI's multilingual embedding model, optimized for non-English (API key required)",
+         "type": "voyage",
+         "model_name": "voyage-multilingual-2",
+         "dimensions": 1024,
+     },
+     "gemini/gemini-embedding-001": {
+         "name": "Gemini Embedding 001",
+         "description": "Google's Gemini embedding model, 3072 dimensions (API key required)",
+         "type": "gemini",
+         "model_name": "gemini-embedding-001",
+         "dimensions": 3072,
+     },
+     "gemini/gemini-embedding-001-768": {
+         "name": "Gemini Embedding 001 (768d)",
+         "description": "Google's Gemini embedding model, 768 dimensions (API key required)",
+         "type": "gemini",
+         "model_name": "gemini-embedding-001",
+         "dimensions": 768,
+     },
+     "gemini/gemini-embedding-001-1536": {
+         "name": "Gemini Embedding 001 (1536d)",
+         "description": "Google's Gemini embedding model, 1536 dimensions (API key required)",
+         "type": "gemini",
+         "model_name": "gemini-embedding-001",
+         "dimensions": 1536,
+     },
+ }
+
+ # Merge all models for easy lookup
+ ALL_MODELS = {**CURATED_MODELS, **API_MODELS}
+
+
+ class BaseEmbeddingModel(ABC):
+     """Abstract base class for embedding models."""
+
+     model_id: str
+     embedding_dim: int
+
+     @abstractmethod
+     def encode(
+         self,
+         texts: list[str],
+         is_query: bool = False,
+         batch_size: int = 32,
+         show_progress: bool = True,
+         normalize: bool = True,
+     ) -> np.ndarray:
+         """Encode texts to embeddings."""
+         pass
+
+     @property
+     @abstractmethod
+     def name(self) -> str:
+         """Get the display name for the model."""
+         pass
+
+     @property
+     @abstractmethod
+     def description(self) -> str:
+         """Get the description for the model."""
+         pass
+
+     def encode_pairs(
+         self,
+         he_texts: list[str],
+         en_texts: list[str],
+         batch_size: int = 32,
+         show_progress: bool = True,
+     ) -> tuple[np.ndarray, np.ndarray]:
+         """
+         Encode parallel Hebrew/English text pairs.
+
+         Args:
+             he_texts: Hebrew/Aramaic source texts
+             en_texts: English translations
+             batch_size: Batch size for encoding
+             show_progress: Whether to show a progress bar
+
+         Returns:
+             Tuple of (hebrew_embeddings, english_embeddings)
+         """
+         he_embeddings = self.encode(
+             he_texts,
+             is_query=True,
+             batch_size=batch_size,
+             show_progress=show_progress,
+         )
+
+         en_embeddings = self.encode(
+             en_texts,
+             is_query=False,
+             batch_size=batch_size,
+             show_progress=show_progress,
+         )
+
+         return he_embeddings, en_embeddings
+
+
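Any new backend only needs to satisfy this interface. A hypothetical, trimmed-down sketch (the `MiniEmbedder` and `ToyEmbedder` names are invented for illustration and are not part of this codebase) of how a toy backend would plug in:

```python
from abc import ABC, abstractmethod
import numpy as np

class MiniEmbedder(ABC):
    """Trimmed-down stand-in for the BaseEmbeddingModel interface."""

    @abstractmethod
    def encode(self, texts: list[str], is_query: bool = False) -> np.ndarray:
        """Return a (len(texts), dim) array of L2-normalized embeddings."""

class ToyEmbedder(MiniEmbedder):
    """Deterministic pseudo-embeddings seeded from each text's bytes."""

    def encode(self, texts: list[str], is_query: bool = False) -> np.ndarray:
        rows = []
        for t in texts:
            # Seed a generator from the text so the same text always maps
            # to the same vector
            rng = np.random.default_rng(list(t.encode("utf-8")))
            rows.append(rng.standard_normal(8))
        emb = np.array(rows)
        return emb / np.linalg.norm(emb, axis=1, keepdims=True)

emb = ToyEmbedder().encode(["shalom", "hello"])
print(emb.shape)  # (2, 8)
```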
212
+ class EmbeddingModel(BaseEmbeddingModel):
213
+ """
214
+ Wrapper for sentence-transformer models with consistent interface.
215
+ """
216
+
217
+ def __init__(
218
+ self,
219
+ model_id: str,
220
+ device: Optional[str] = None,
221
+ max_length: int = 512,
222
+ ):
223
+ """
224
+ Initialize the embedding model.
225
+
226
+ Args:
227
+ model_id: Hugging Face model ID
228
+ device: Device to use ('cuda', 'cpu', or None for auto)
229
+ max_length: Maximum sequence length for tokenization
230
+ """
231
+ from sentence_transformers import SentenceTransformer
232
+ import torch
233
+
234
+ self.model_id = model_id
235
+ self.max_length = max_length
236
+
237
+ # Auto-detect device
238
+ if device is None:
239
+ device = "cuda" if torch.cuda.is_available() else "cpu"
240
+ self.device = device
241
+
242
+ # Get model config if it's a curated model
243
+ self.config = CURATED_MODELS.get(model_id, {
244
+ "name": model_id.split("/")[-1],
245
+ "description": "Custom model",
246
+ "type": "local",
247
+ "query_prefix": "",
248
+ "passage_prefix": "",
249
+ })
250
+
251
+ # Load the model
252
+ print(f"Loading model: {model_id} on {device}")
253
+ self.model = SentenceTransformer(model_id, device=device)
254
+
255
+ # Set max sequence length if supported
256
+ if hasattr(self.model, "max_seq_length"):
257
+ self.model.max_seq_length = min(max_length, self.model.max_seq_length)
258
+
259
+ self.embedding_dim = self.model.get_sentence_embedding_dimension()
260
+ print(f"Model loaded. Embedding dimension: {self.embedding_dim}")
261
+
262
+ def encode(
263
+ self,
264
+ texts: list[str],
265
+ is_query: bool = False,
266
+ batch_size: int = 32,
267
+ show_progress: bool = True,
268
+ normalize: bool = True,
269
+ ) -> np.ndarray:
270
+ """
271
+ Encode texts to embeddings.
272
+
273
+ Args:
274
+ texts: List of texts to encode
275
+ is_query: Whether these are queries (vs passages) for asymmetric models
276
+ batch_size: Batch size for encoding
277
+ show_progress: Whether to show progress bar
278
+ normalize: Whether to L2-normalize embeddings
279
+
280
+ Returns:
281
+ numpy array of shape (len(texts), embedding_dim)
282
+ """
283
+ # Add prefix if needed (for E5-style models)
284
+ prefix = self.config["query_prefix"] if is_query else self.config["passage_prefix"]
285
+ if prefix:
286
+ texts = [prefix + t for t in texts]
287
+
288
+ embeddings = self.model.encode(
289
+ texts,
290
+ batch_size=batch_size,
291
+ show_progress_bar=show_progress,
292
+ normalize_embeddings=normalize,
293
+ convert_to_numpy=True,
294
+ )
295
+
296
+ return embeddings
297
+
298
+ @property
299
+ def name(self) -> str:
300
+ """Get display name for the model."""
301
+ return self.config.get("name", self.model_id)
302
+
303
+ @property
304
+ def description(self) -> str:
305
+ """Get description for the model."""
306
+ return self.config.get("description", "")
307
+
308
+
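E5-style models expect different prefixes for queries and passages, which the wrapper prepends inside `encode`. A minimal sketch of that prefixing step (the `apply_prefix` helper is invented for illustration; the prefix values match the curated E5 config):

```python
# E5-style asymmetric prefixes, as in the curated config
config = {"query_prefix": "query: ", "passage_prefix": "passage: "}

def apply_prefix(texts: list[str], is_query: bool) -> list[str]:
    # Pick the prefix by role; models with empty prefixes pass through unchanged
    prefix = config["query_prefix"] if is_query else config["passage_prefix"]
    return [prefix + t for t in texts] if prefix else list(texts)

queries = apply_prefix(["מאי חנוכה"], is_query=True)
passages = apply_prefix(["What is Hanukkah?"], is_query=False)
print(queries[0])   # "query: מאי חנוכה"
print(passages[0])  # "passage: What is Hanukkah?"
```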
+ class OpenAIEmbeddingModel(BaseEmbeddingModel):
+     """
+     Wrapper for the OpenAI embedding API with a consistent interface.
+     """
+
+     # OpenAI embedding models have an 8191-token input limit
+     MAX_TOKENS = 8191
+
+     def __init__(
+         self,
+         model_id: str,
+         api_key: Optional[str] = None,
+     ):
+         """
+         Initialize the OpenAI embedding model.
+
+         Args:
+             model_id: Model ID in the format 'openai/model-name'
+             api_key: OpenAI API key (or uses the OPENAI_API_KEY env var)
+         """
+         try:
+             from openai import OpenAI
+         except ImportError:
+             raise ImportError(
+                 "OpenAI package not installed. Install with: pip install openai"
+             )
+
+         self.model_id = model_id
+
+         # Get the API key from the parameter or environment
+         api_key = api_key or os.environ.get("OPENAI_API_KEY")
+         if not api_key:
+             raise ValueError(
+                 "OpenAI API key required. Set the OPENAI_API_KEY environment variable "
+                 "or pass the api_key parameter."
+             )
+
+         self.client = OpenAI(api_key=api_key)
+
+         # Get the model config
+         self.config = API_MODELS.get(model_id, {
+             "name": model_id,
+             "description": "OpenAI embedding model",
+             "type": "openai",
+             "model_name": model_id.replace("openai/", ""),
+             "dimensions": 1536,
+         })
+
+         self._model_name = self.config["model_name"]
+         self.embedding_dim = self.config["dimensions"]
+
+         # Initialize a tokenizer for truncation
+         self._encoding = None
+         try:
+             import tiktoken
+             self._encoding = tiktoken.encoding_for_model(self._model_name)
+         except Exception:
+             # Fall back to cl100k_base, the encoding used by the embedding models
+             try:
+                 import tiktoken
+                 self._encoding = tiktoken.get_encoding("cl100k_base")
+             except Exception:
+                 print("Warning: tiktoken not available, using character-based truncation")
+
+         print(f"Initialized OpenAI embedding model: {self._model_name}")
+         print(f"Embedding dimension: {self.embedding_dim}")
+
+     def _truncate_text(self, text: str) -> str:
+         """Truncate text to fit within the token limit."""
+         if self._encoding is not None:
+             # Use tiktoken for accurate token counting
+             tokens = self._encoding.encode(text)
+             if len(tokens) > self.MAX_TOKENS:
+                 tokens = tokens[:self.MAX_TOKENS]
+                 return self._encoding.decode(tokens)
+             return text
+         else:
+             # Fallback: rough character-based truncation,
+             # assuming ~3 chars per token for Hebrew/mixed text (conservative)
+             max_chars = self.MAX_TOKENS * 3
+             if len(text) > max_chars:
+                 return text[:max_chars]
+             return text
+
+     def encode(
+         self,
+         texts: list[str],
+         is_query: bool = False,
+         batch_size: int = 100,  # OpenAI supports larger batches
+         show_progress: bool = True,
+         normalize: bool = True,
+     ) -> np.ndarray:
+         """
+         Encode texts to embeddings using the OpenAI API.
+
+         Args:
+             texts: List of texts to encode
+             is_query: Not used for OpenAI (symmetric embeddings)
+             batch_size: Batch size for API calls
+             show_progress: Whether to show a progress bar
+             normalize: Whether to L2-normalize the embeddings (OpenAI already normalizes)
+
+         Returns:
+             numpy array of shape (len(texts), embedding_dim)
+         """
+         import time
+
+         all_embeddings = []
+         total_batches = (len(texts) + batch_size - 1) // batch_size
+
+         for i in range(0, len(texts), batch_size):
+             # Truncate each text to the token limit before sending it to the API
+             batch = [self._truncate_text(t) for t in texts[i:i + batch_size]]
+             batch_num = i // batch_size + 1
+
+             if show_progress:
+                 print(f"  Encoding batch {batch_num}/{total_batches}...")
+
+             # Retry logic for API calls
+             max_retries = 3
+             for attempt in range(max_retries):
+                 try:
+                     response = self.client.embeddings.create(
+                         model=self._model_name,
+                         input=batch,
+                     )
+
+                     # Extract the embeddings from the response
+                     batch_embeddings = [item.embedding for item in response.data]
+                     all_embeddings.extend(batch_embeddings)
+                     break
+
+                 except Exception as e:
+                     if attempt < max_retries - 1:
+                         wait_time = 2 ** attempt
+                         print(f"  API error, retrying in {wait_time}s: {e}")
+                         time.sleep(wait_time)
+                     else:
+                         raise RuntimeError(f"OpenAI API error after {max_retries} retries: {e}")
+
+             # Small delay to avoid rate limits
+             if i + batch_size < len(texts):
+                 time.sleep(0.1)
+
+         embeddings = np.array(all_embeddings, dtype=np.float32)
+
+         # OpenAI embeddings are already normalized, but re-normalize if requested
+         if normalize:
+             norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
+             embeddings = embeddings / np.maximum(norms, 1e-10)
+
+         return embeddings
+
+     @property
+     def name(self) -> str:
+         """Get the display name for the model."""
+         return self.config.get("name", self.model_id)
+
+     @property
+     def description(self) -> str:
+         """Get the description for the model."""
+         return self.config.get("description", "")
+
+
+ class VoyageEmbeddingModel(BaseEmbeddingModel):
+     """
+     Wrapper for the Voyage AI embedding API with a consistent interface.
+     """
+
+     def __init__(
+         self,
+         model_id: str,
+         api_key: Optional[str] = None,
+     ):
+         """
+         Initialize the Voyage AI embedding model.
+
+         Args:
+             model_id: Model ID in the format 'voyage/model-name'
+             api_key: Voyage API key (or uses the VOYAGE_API_KEY env var)
+         """
+         try:
+             import voyageai
+         except ImportError:
+             raise ImportError(
+                 "Voyage AI package not installed. Install with: pip install voyageai"
+             )
+
+         self.model_id = model_id
+
+         # Get the API key from the parameter or environment
+         api_key = api_key or os.environ.get("VOYAGE_API_KEY")
+         if not api_key:
+             raise ValueError(
+                 "Voyage API key required. Set the VOYAGE_API_KEY environment variable "
+                 "or pass the api_key parameter."
+             )
+
+         self.client = voyageai.Client(api_key=api_key)
+
+         # Get the model config
+         self.config = API_MODELS.get(model_id, {
+             "name": model_id,
+             "description": "Voyage AI embedding model",
+             "type": "voyage",
+             "model_name": model_id.replace("voyage/", ""),
+             "dimensions": 1024,  # Default dimension
+         })
+
+         self._model_name = self.config["model_name"]
+         self.embedding_dim = self.config["dimensions"]
+
+         print(f"Initialized Voyage AI embedding model: {self._model_name}")
+         print(f"Embedding dimension: {self.embedding_dim}")
+
+     def encode(
+         self,
+         texts: list[str],
+         is_query: bool = False,
+         batch_size: int = 128,  # Voyage supports larger batches
+         show_progress: bool = True,
+         normalize: bool = True,
+     ) -> np.ndarray:
+         """
+         Encode texts to embeddings using the Voyage AI API.
+
+         Args:
+             texts: List of texts to encode
+             is_query: Whether these are queries (Voyage supports input_type)
+             batch_size: Batch size for API calls
+             show_progress: Whether to show a progress bar
+             normalize: Whether to L2-normalize the embeddings
+
+         Returns:
+             numpy array of shape (len(texts), embedding_dim)
+         """
+         import time
+
+         all_embeddings = []
+         total_batches = (len(texts) + batch_size - 1) // batch_size
+
+         # Voyage supports input_type for asymmetric embeddings
+         input_type = "query" if is_query else "document"
+
+         for i in range(0, len(texts), batch_size):
+             batch = texts[i:i + batch_size]
+             batch_num = i // batch_size + 1
+
+             if show_progress:
+                 print(f"  Encoding batch {batch_num}/{total_batches}...")
+
+             # Retry logic for API calls
+             max_retries = 3
+             for attempt in range(max_retries):
+                 try:
+                     result = self.client.embed(
+                         batch,
+                         model=self._model_name,
+                         input_type=input_type,
+                     )
+
+                     # Extract the embeddings from the response
+                     batch_embeddings = result.embeddings
+                     all_embeddings.extend(batch_embeddings)
+                     break
+
+                 except Exception as e:
+                     if attempt < max_retries - 1:
+                         wait_time = 2 ** attempt
+                         print(f"  API error, retrying in {wait_time}s: {e}")
+                         time.sleep(wait_time)
+                     else:
+                         raise RuntimeError(f"Voyage AI API error after {max_retries} retries: {e}")
+
+             # Small delay to avoid rate limits
+             if i + batch_size < len(texts):
+                 time.sleep(0.1)
+
+         embeddings = np.array(all_embeddings, dtype=np.float32)
+
+         # Normalize if requested
+         if normalize:
+             norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
+             embeddings = embeddings / np.maximum(norms, 1e-10)
+
+         return embeddings
+
+     @property
+     def name(self) -> str:
+         """Get the display name for the model."""
+         return self.config.get("name", self.model_id)
+
+     @property
+     def description(self) -> str:
+         """Get the description for the model."""
+         return self.config.get("description", "")
+
+
+ class GeminiEmbeddingModel(BaseEmbeddingModel):
+     """
+     Wrapper for the Google Gemini embedding API with a consistent interface.
+     """
+
+     def __init__(
+         self,
+         model_id: str,
+         api_key: Optional[str] = None,
+     ):
+         """
+         Initialize the Gemini embedding model.
+
+         Args:
+             model_id: Model ID in the format 'gemini/model-name'
+             api_key: Gemini API key (optional - can use the GEMINI_API_KEY env var
+                 or Google Cloud Application Default Credentials)
+         """
+         try:
+             from google import genai
+         except ImportError:
+             raise ImportError(
+                 "Google GenAI package not installed. Install with: pip install google-genai"
+             )
+
+         self.model_id = model_id
+
+         # Get the API key from the parameter or environment (optional - ADC also works)
+         api_key = api_key or os.environ.get("GEMINI_API_KEY")
+
+         # Create the client; without an API key, Application Default Credentials are used
+         if api_key:
+             self.client = genai.Client(api_key=api_key)
+         else:
+             # Requires `gcloud auth application-default login`
+             self.client = genai.Client()
+
+         # Get the model config
+         self.config = API_MODELS.get(model_id, {
+             "name": model_id,
+             "description": "Gemini embedding model",
+             "type": "gemini",
+             "model_name": model_id.replace("gemini/", "").split("-768")[0].split("-1536")[0],
+             "dimensions": 3072,  # Default dimension
+         })
+
+         self._model_name = self.config["model_name"]
+         self.embedding_dim = self.config["dimensions"]
+
+         print(f"Initialized Gemini embedding model: {self._model_name}")
+         print(f"Embedding dimension: {self.embedding_dim}")
+
+     def encode(
+         self,
+         texts: list[str],
+         is_query: bool = False,
+         batch_size: int = 20,  # Smaller batches to avoid rate limits
+         show_progress: bool = True,
+         normalize: bool = True,
+     ) -> np.ndarray:
+         """
+         Encode texts to embeddings using the Gemini API.
+
+         Args:
+             texts: List of texts to encode
+             is_query: Whether these are queries (uses RETRIEVAL_QUERY vs RETRIEVAL_DOCUMENT)
+             batch_size: Batch size for API calls (smaller for Gemini to avoid rate limits)
+             show_progress: Whether to show a progress bar
+             normalize: Whether to L2-normalize the embeddings
+
+         Returns:
+             numpy array of shape (len(texts), embedding_dim)
+         """
+         import time
+         import random
+         from google.genai import types
+
+         all_embeddings = []
+         total_batches = (len(texts) + batch_size - 1) // batch_size
+
+         # Gemini supports task_type for asymmetric embeddings
+         task_type = "RETRIEVAL_QUERY" if is_query else "RETRIEVAL_DOCUMENT"
+
+         for i in range(0, len(texts), batch_size):
+             batch = texts[i:i + batch_size]
+             batch_num = i // batch_size + 1
+
+             if show_progress:
+                 print(f"  Encoding batch {batch_num}/{total_batches}...")
+
+             # Retry logic with exponential backoff for rate limits
+             max_retries = 8
+             base_delay = 2.0
+
+             for attempt in range(max_retries):
701
+ try:
702
+ # Build config with task type and output dimensionality
703
+ embed_config = types.EmbedContentConfig(
704
+ task_type=task_type,
705
+ output_dimensionality=self.embedding_dim,
706
+ )
707
+
708
+ result = self.client.models.embed_content(
709
+ model=self._model_name,
710
+ contents=batch,
711
+ config=embed_config,
712
+ )
713
+
714
+ # Extract embeddings from response
715
+ batch_embeddings = [e.values for e in result.embeddings]
716
+ all_embeddings.extend(batch_embeddings)
717
+ break
718
+
719
+ except Exception as e:
720
+ error_str = str(e)
721
+ is_rate_limit = "429" in error_str or "RESOURCE_EXHAUSTED" in error_str
722
+
723
+ if attempt < max_retries - 1:
724
+ # Exponential backoff with jitter
725
+ # Longer waits for rate limit errors
726
+ if is_rate_limit:
727
+ wait_time = base_delay * (2 ** attempt) + random.uniform(1, 5)
728
+ print(f" Rate limited, waiting {wait_time:.1f}s before retry {attempt + 2}/{max_retries}...")
729
+ else:
730
+ wait_time = base_delay * (2 ** attempt) + random.uniform(0, 1)
731
+ print(f" API error, retrying in {wait_time:.1f}s: {e}")
732
+ time.sleep(wait_time)
733
+ else:
734
+ raise RuntimeError(f"Gemini API error after {max_retries} retries: {e}")
735
+
736
+ # Delay between batches to avoid rate limits (longer for Gemini)
737
+ if i + batch_size < len(texts):
738
+ time.sleep(0.5)
739
+
740
+ embeddings = np.array(all_embeddings, dtype=np.float32)
741
+
742
+ # Normalize if requested (Gemini's 3072d is normalized, but smaller dims need it)
743
+ if normalize:
744
+ norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
745
+ embeddings = embeddings / np.maximum(norms, 1e-10)
746
+
747
+ return embeddings
748
+
749
+ @property
750
+ def name(self) -> str:
751
+ """Get display name for the model."""
752
+ return self.config.get("name", self.model_id)
753
+
754
+ @property
755
+ def description(self) -> str:
756
+ """Get description for the model."""
757
+ return self.config.get("description", "")
758
+
759
+
760
+ def get_curated_model_choices() -> list[tuple[str, str]]:
+     """
+     Get list of curated local models for UI dropdown.
+ 
+     Returns:
+         List of (model_id, display_name) tuples
+     """
+     return [
+         (model_id, f"{info['name']} - {info['description']}")
+         for model_id, info in CURATED_MODELS.items()
+     ]
+ 
+ 
+ def get_api_model_choices() -> list[tuple[str, str]]:
+     """
+     Get list of API-based models for UI dropdown.
+ 
+     Returns:
+         List of (model_id, display_name) tuples
+     """
+     return [
+         (model_id, f"{info['name']} - {info['description']}")
+         for model_id, info in API_MODELS.items()
+     ]
+ 
+ 
+ def get_all_model_choices() -> list[tuple[str, str]]:
+     """
+     Get list of all models (local + API) for UI dropdown.
+ 
+     Returns:
+         List of (model_id, display_name) tuples
+     """
+     return get_curated_model_choices() + get_api_model_choices()
+ 
+ 
+ def is_api_model(model_id: str) -> bool:
+     """Check if a model ID is an API-based model."""
+     model_id = model_id.strip()
+ 
+     # Check if it's in API_MODELS
+     if model_id in API_MODELS:
+         return True
+ 
+     # Check if it starts with known API prefixes
+     if model_id.startswith("openai/"):
+         return True
+     if model_id.startswith("voyage/"):
+         return True
+     if model_id.startswith("gemini/"):
+         return True
+ 
+     return False
+ 
+ 
+ def load_model(
+     model_id: str,
+     device: Optional[str] = None,
+     api_key: Optional[str] = None,
+ ) -> BaseEmbeddingModel:
+     """
+     Load an embedding model by ID.
+ 
+     Args:
+         model_id: Model ID (HuggingFace model ID or API model like 'openai/text-embedding-3-large')
+         device: Device to use (for local models only)
+         api_key: API key (for API-based models; falls back to the environment variable)
+ 
+     Returns:
+         Loaded embedding model instance
+     """
+     model_id = model_id.strip()
+ 
+     # Check if this is an API model
+     if is_api_model(model_id):
+         # Check model type from config or prefix
+         model_config = API_MODELS.get(model_id, {})
+         model_type = model_config.get("type", "")
+ 
+         if model_type == "voyage" or model_id.startswith("voyage/"):
+             return VoyageEmbeddingModel(model_id, api_key=api_key)
+         elif model_type == "gemini" or model_id.startswith("gemini/"):
+             return GeminiEmbeddingModel(model_id, api_key=api_key)
+         elif model_type == "openai" or model_id.startswith("openai/"):
+             return OpenAIEmbeddingModel(model_id, api_key=api_key)
+         else:
+             raise ValueError(f"Unknown API model type: {model_id}")
+ 
+     # Otherwise, load as a local sentence-transformer model
+     return EmbeddingModel(model_id, device=device)
+ 
+ 
+ def validate_model_id(model_id: str) -> tuple[bool, str]:
+     """
+     Check if a model ID is valid and loadable.
+ 
+     Args:
+         model_id: The model ID to validate
+ 
+     Returns:
+         Tuple of (is_valid, error_message)
+     """
+     if not model_id or not model_id.strip():
+         return False, "Model ID cannot be empty"
+ 
+     model_id = model_id.strip()
+ 
+     # Check if it's a curated local model
+     if model_id in CURATED_MODELS:
+         return True, ""
+ 
+     # Check if it's a known API model
+     if model_id in API_MODELS:
+         return True, ""
+ 
+     # Check for OpenAI models
+     if model_id.startswith("openai/"):
+         return True, ""
+ 
+     # Check for Voyage AI models
+     if model_id.startswith("voyage/"):
+         return True, ""
+ 
+     # Check for Gemini models
+     if model_id.startswith("gemini/"):
+         return True, ""
+ 
+     # For custom models, check if it looks like a valid HF model ID
+     if "/" not in model_id:
+         return False, "Model ID should be in format 'organization/model-name'"
+ 
+     # Could add an API check here, but that would slow down validation
+     return True, ""
+ 
+ 
+ def requires_api_key(model_id: str) -> bool:
+     """Check if a model requires an API key."""
+     return is_api_model(model_id)
+ 
+ 
+ def api_key_optional(model_id: str) -> bool:
+     """
+     Check if an API key is optional for this model.
+ 
+     Some providers (like Google Gemini) support Application Default Credentials
+     as an alternative to explicit API keys.
+     """
+     key_type = get_api_key_type(model_id)
+     # Gemini supports ADC (gcloud auth application-default login)
+     return key_type == "gemini"
+ 
+ 
+ def get_api_key_type(model_id: str) -> Optional[str]:
+     """
+     Get the type of API key required for a model.
+ 
+     Args:
+         model_id: The model ID
+ 
+     Returns:
+         'openai', 'voyage', 'gemini', or None if no API key is needed
+     """
+     if not is_api_model(model_id):
+         return None
+ 
+     model_id = model_id.strip()
+     model_config = API_MODELS.get(model_id, {})
+     model_type = model_config.get("type", "")
+ 
+     if model_type == "voyage" or model_id.startswith("voyage/"):
+         return "voyage"
+     elif model_type == "gemini" or model_id.startswith("gemini/"):
+         return "gemini"
+     elif model_type == "openai" or model_id.startswith("openai/"):
+         return "openai"
+ 
+     return None
+ 
+ 
+ def get_api_key_env_var(model_id: str) -> Optional[str]:
+     """
+     Get the environment variable name for the API key required by a model.
+ 
+     Args:
+         model_id: The model ID
+ 
+     Returns:
+         Environment variable name or None
+     """
+     key_type = get_api_key_type(model_id)
+     if key_type == "openai":
+         return "OPENAI_API_KEY"
+     elif key_type == "voyage":
+         return "VOYAGE_API_KEY"
+     elif key_type == "gemini":
+         return "GEMINI_API_KEY"
+     return None
+ 
+ 
+ if __name__ == "__main__":
+     import argparse
+ 
+     parser = argparse.ArgumentParser(
+         description="Test embedding model loading and encoding"
+     )
+     parser.add_argument(
+         "--local",
+         action="store_true",
+         help="Test only local sentence-transformer models",
+     )
+     parser.add_argument(
+         "--remote",
+         action="store_true",
+         help="Test only remote/API models (requires API keys)",
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         default=None,
+         help="Test a specific model ID",
+     )
+ 
+     args = parser.parse_args()
+ 
+     # If neither flag is specified, test both
+     test_local = args.local or (not args.local and not args.remote)
+     test_remote = args.remote or (not args.local and not args.remote)
+ 
+     print("Testing model loading...")
+ 
+     print("\nLocal models available:")
+     for model_id, display in get_curated_model_choices():
+         print(f"  - {display}")
+ 
+     print("\nAPI models available:")
+     for model_id, display in get_api_model_choices():
+         print(f"  - {display}")
+ 
+     # Test texts
+     test_texts = [
+         "בראשית ברא אלהים את השמים ואת הארץ",
+         "In the beginning God created the heaven and the earth",
+     ]
+ 
+     def run_model_test(model_id: str, model_type: str):
+         """Run a test for a specific model."""
+         print(f"\n{'='*60}")
+         print(f"Testing {model_type}: {model_id}")
+         print("="*60)
+ 
+         try:
+             model = load_model(model_id)
+ 
+             embeddings = model.encode(test_texts, show_progress=False)
+             print(f"\nEncoded {len(test_texts)} texts")
+             print(f"Embedding shape: {embeddings.shape}")
+ 
+             similarity = np.dot(embeddings[0], embeddings[1])
+             print(f"Cosine similarity between Hebrew and English: {similarity:.4f}")
+             return True
+         except Exception as e:
+             print(f"Test failed: {e}")
+             return False
+ 
+     # Test a specific model if provided
+     if args.model:
+         run_model_test(args.model, "specified model")
+     else:
+         # Test local model
+         if test_local:
+             run_model_test(
+                 "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
+                 "local sentence-transformer model"
+             )
+ 
+         # Test API models
+         if test_remote:
+             # Test OpenAI model
+             if os.environ.get("OPENAI_API_KEY"):
+                 run_model_test(
+                     "openai/text-embedding-3-small",
+                     "OpenAI API model"
+                 )
+             else:
+                 print("\n(Skipping OpenAI test - OPENAI_API_KEY not set)")
+ 
+             # Test Voyage AI model
+             if os.environ.get("VOYAGE_API_KEY"):
+                 run_model_test(
+                     "voyage/voyage-3.5",
+                     "Voyage AI API model"
+                 )
+             else:
+                 print("\n(Skipping Voyage AI test - VOYAGE_API_KEY not set)")
+ 
+             # Test Gemini model
+             if os.environ.get("GEMINI_API_KEY"):
+                 run_model_test(
+                     "gemini/gemini-embedding-001",
+                     "Gemini API model"
+                 )
+             else:
+                 print("\n(Skipping Gemini test - GEMINI_API_KEY not set)")
+ 
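Both API wrappers above follow the same retry shape: exponential backoff with jitter, capped at `max_retries`, re-raising as `RuntimeError` once retries are exhausted. A minimal standalone sketch of that pattern (the `flaky_call` helper is hypothetical, and `base_delay` is shortened here for illustration):

```python
import random
import time


def call_with_backoff(fn, max_retries=8, base_delay=2.0):
    """Retry fn with exponential backoff plus jitter, as in the encode() loops."""
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception as e:
            if attempt == max_retries - 1:
                raise RuntimeError(f"API error after {max_retries} retries: {e}")
            # Wait base_delay * 2**attempt seconds, plus up to 1s of jitter
            wait_time = base_delay * (2 ** attempt) + random.uniform(0, 1)
            time.sleep(wait_time)


# Hypothetical flaky call that fails twice before succeeding
attempts = {"n": 0}

def flaky_call():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ValueError("transient error")
    return "ok"

print(call_with_backoff(flaky_call, base_delay=0.01))  # succeeds on the third attempt
```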
remove_oversize_entries.py ADDED
@@ -0,0 +1,51 @@
+ """
+ One-time script to remove entries exceeding the OpenAI embedding token limit
+ from the benchmark dataset.
+ """
+ 
+ import json
+ 
+ # The refs to remove (from the token limit check report)
+ REFS_TO_REMOVE = [
+     "Shemot Rabbah.1:1",
+     "Bamidbar Rabbah.1:2",
+     "Bamidbar Rabbah.2:10",
+     "Shir HaShirim Rabbah.1.1:10",
+     "Eichah Rabbah.1:4",
+     "Eichah Rabbah.1:23",
+     "Eichah Rabbah.1:31",
+     "Ramban on Genesis.18:1",
+     "Ramban on Genesis.24:2",
+     "Ramban on Leviticus.1.9:1",
+     "Ramban on Numbers.16:1",
+     "Ramban on Numbers.24:1",
+     "Ramban on Deuteronomy.2.23:1",
+ ]
+ 
+ def main():
+     data_path = "benchmark_data/benchmark.json"
+ 
+     # Load the data
+     print(f"Loading data from: {data_path}")
+     with open(data_path, "r", encoding="utf-8") as f:
+         data = json.load(f)
+ 
+     original_count = len(data)
+     print(f"Original entry count: {original_count}")
+ 
+     # Filter out the flagged entries
+     filtered_data = [entry for entry in data if entry["ref"] not in REFS_TO_REMOVE]
+ 
+     removed_count = original_count - len(filtered_data)
+     print(f"Removed {removed_count} entries")
+     print(f"New entry count: {len(filtered_data)}")
+ 
+     # Save the filtered data
+     print(f"Saving filtered data to: {data_path}")
+     with open(data_path, "w", encoding="utf-8") as f:
+         json.dump(filtered_data, f, ensure_ascii=False, indent=2)
+ 
+     print("Done!")
+ 
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,20 @@
+ # Hugging Face Space dependencies
+ gradio>=4.0.0
+ transformers>=4.36.0
+ sentence-transformers>=2.2.2
+ torch>=2.0.0
+ 
+ # Data processing
+ numpy>=1.24.0
+ pandas>=2.0.0
+ requests>=2.31.0
+ 
+ # Visualization
+ plotly>=5.18.0
+ 
+ # API-based embedding providers
+ openai>=1.0.0
+ tiktoken>=0.5.0
+ voyageai>=0.3.0
+ google-genai>=1.0.0
+ 
space_README.md ADDED
@@ -0,0 +1,48 @@
+ ---
+ title: Rabbinic Embedding Benchmark
+ emoji: 📚
+ colorFrom: blue
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 4.44.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+ 
+ # Rabbinic Hebrew/Aramaic Embedding Benchmark
+ 
+ Evaluate embedding models on cross-lingual retrieval between Hebrew/Aramaic source texts and their English translations from Sefaria.
+ 
+ ## How It Works
+ 
+ Given a Hebrew/Aramaic text, can the model find its correct English translation from a pool of candidates? Models that excel at this task are likely to produce high-quality embeddings for Rabbinic literature.
+ 
+ ## Metrics
+ 
+ | Metric | Description |
+ |--------|-------------|
+ | **MRR** | Mean Reciprocal Rank (average of 1/rank of the correct answer) |
+ | **Recall@k** | % of queries where the correct translation is in the top k results |
+ | **Bitext Accuracy** | Accuracy at distinguishing true translation pairs from random pairs |
+ 
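To make the retrieval metrics concrete, here is a minimal sketch (not the Space's actual scoring code) that computes MRR and Recall@k from the 1-based rank of the correct translation for each query:

```python
def mrr(ranks):
    """Mean Reciprocal Rank: average of 1/rank (ranks are 1-based)."""
    return sum(1.0 / r for r in ranks) / len(ranks)


def recall_at_k(ranks, k):
    """Fraction of queries whose correct translation ranked in the top k."""
    return sum(1 for r in ranks if r <= k) / len(ranks)


# Example: correct translation ranked 1st, 3rd, and 12th for three queries
ranks = [1, 3, 12]
print(round(mrr(ranks), 3))   # (1 + 1/3 + 1/12) / 3
print(recall_at_k(ranks, 5))  # 2 of the 3 queries land in the top 5
```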
+ ## Corpus
+ 
+ The benchmark includes diverse texts with English translations:
+ 
+ - **Talmud**: Bavli & Yerushalmi
+ - **Mishnah**: Selected tractates
+ - **Midrash**: Midrash Rabbah
+ - **Commentary**: Rashi, Ramban, Radak, Rabbeinu Behaye
+ - **Philosophy**: Guide for the Perplexed, Sefer HaIkkarim
+ - **Hasidic/Kabbalistic**: Likutei Moharan, Tomer Devorah, Kalach Pitchei Chokhmah
+ - **Mussar**: Chafetz Chaim, Kav HaYashar, Iggeret HaRamban
+ - **Halacha**: Sefer HaChinukh, Mishneh Torah
+ 
+ All texts sourced from [Sefaria](https://www.sefaria.org).
+ 
+ ## API Keys
+ 
+ For API-based models (OpenAI, Voyage AI, Gemini), you can either:
+ - Enter your API key in the interface (not stored)
+ - Set environment variables in Space settings: `OPENAI_API_KEY`, `VOYAGE_API_KEY`, `GEMINI_API_KEY`