Lev Israel committed
Commit 018c4c5: Initial Commit
Files changed:
- .gitignore +26 -0
- README.md +58 -0
- app.py +498 -0
- benchmark-stats.txt +14 -0
- benchmark_data/benchmark.json +0 -0
- build_benchmark.py +190 -0
- check_token_limits.py +306 -0
- data_loader.py +828 -0
- evaluation.py +384 -0
- models.py +1063 -0
- remove_oversize_entries.py +51 -0
- requirements.txt +20 -0
- space_README.md +48 -0
.gitignore
ADDED
@@ -0,0 +1,26 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
venv/
.venv/
env/

# IDE
.vscode/
.idea/
*.swp
*.swo

# OS
.DS_Store
Thumbs.db

# Secrets (never commit these!)
.env
*.env

# Leaderboard data (regenerated on Space)
benchmark_data/leaderboard.json
README.md
ADDED
@@ -0,0 +1,58 @@
# Rabbinic Hebrew/Aramaic Embedding Evaluation

A Hugging Face Space for evaluating embedding models on Rabbinic Hebrew and Aramaic texts using cross-lingual retrieval benchmarks.

## Overview

This tool helps identify which embedding models best capture the semantics of Rabbinic Hebrew and Aramaic by measuring how well they align source texts with their English translations. Models that excel at this task are likely to produce high-quality embeddings for untranslated texts.

## Evaluation Approach

Given a Hebrew/Aramaic text, the benchmark tests whether the embedding model can find its correct English translation from a pool of candidates. This cross-lingual retrieval task measures semantic alignment across languages.

### Metrics

| Metric | Description |
|--------|-------------|
| **Recall@1** | % of queries where the correct translation is the top result |
| **Recall@5** | % where the correct translation is in the top 5 results |
| **Recall@10** | % where the correct translation is in the top 10 results |
| **MRR** | Mean Reciprocal Rank (average of 1/rank of the correct answer) |
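Concretely, each metric can be computed from the rank at which the correct translation appears for each query. A minimal sketch (illustrative only; `retrieval_metrics` is not a function in this repo):

```python
def retrieval_metrics(ranks):
    """Compute Recall@k and MRR from the 1-based rank of each correct translation."""
    n = len(ranks)
    return {
        "recall_at_1": sum(r <= 1 for r in ranks) / n,
        "recall_at_5": sum(r <= 5 for r in ranks) / n,
        "recall_at_10": sum(r <= 10 for r in ranks) / n,
        "mrr": sum(1.0 / r for r in ranks) / n,
    }

# Three queries whose correct translations ranked 1st, 3rd, and 12th
metrics = retrieval_metrics([1, 3, 12])
```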
## Corpus

The benchmark includes diverse texts from Sefaria with English translations:

- **Talmud**: Bavli and Yerushalmi (Aramaic + Hebrew)
- **Mishnah**: All tractates (Rabbinic Hebrew)
- **Midrash**: Midrash Rabbah (Hebrew/Aramaic)
- **Tanakh Commentary**: Rashi and Ramban on Tanakh (Hebrew)
- **Hasidic/Kabbalistic**: Likutei Moharan, Tomer Devorah (Hebrew)
- **Halacha**: Sefer HaHinuch, Intro to Shev Shmateta (Hebrew)

## Usage

1. Select a model from the curated list or enter any Hugging Face model ID
2. Click "Run Evaluation"
3. View results and compare with the leaderboard

## Models

### Curated Models

- `intfloat/multilingual-e5-large`
- `sentence-transformers/paraphrase-multilingual-mpnet-base-v2`
- `BAAI/bge-m3`

You can also evaluate any sentence-transformers-compatible model from Hugging Face Hub.
## Local Development

```bash
pip install -r requirements.txt
python app.py
```

## License

MIT
app.py
ADDED
@@ -0,0 +1,498 @@
"""
Gradio interface for Rabbinic Hebrew/Aramaic Embedding Evaluation.

A Hugging Face Space for evaluating embedding models on cross-lingual
retrieval between Hebrew/Aramaic source texts and English translations.
"""

import json
import os
from datetime import datetime
from pathlib import Path

import gradio as gr
import pandas as pd
import plotly.graph_objects as go

from data_loader import load_benchmark_dataset, get_benchmark_stats
from models import (
    CURATED_MODELS,
    API_MODELS,
    ALL_MODELS,
    get_curated_model_choices,
    get_api_model_choices,
    get_all_model_choices,
    load_model,
    validate_model_id,
    is_api_model,
    requires_api_key,
    api_key_optional,
    get_api_key_type,
    get_api_key_env_var,
)
from evaluation import (
    EvaluationResults,
    evaluate_model,
    evaluate_model_streaming,
    compute_similarity_matrix,
    get_rank_distribution,
)

# Paths
BENCHMARK_PATH = "benchmark_data/benchmark.json"
LEADERBOARD_PATH = "benchmark_data/leaderboard.json"

# Global state
_benchmark_data = None
_leaderboard = []


def load_benchmark():
    """Load benchmark data, with fallback to sample data."""
    global _benchmark_data

    if _benchmark_data is not None:
        return _benchmark_data

    try:
        _benchmark_data = load_benchmark_dataset(BENCHMARK_PATH)
        print(f"Loaded {len(_benchmark_data)} benchmark pairs")
    except FileNotFoundError:
        print("Benchmark not found, using sample data")
        # Create minimal sample data for testing
        _benchmark_data = [
            {
                "ref": "Sample.1",
                "he": "בראשית ברא אלהים את השמים ואת הארץ",
                "en": "In the beginning God created the heaven and the earth",
                "category": "Sample",
            },
            {
                "ref": "Sample.2",
                "he": "והארץ היתה תהו ובהו וחשך על פני תהום",
                "en": "And the earth was without form, and void; and darkness was upon the face of the deep",
                "category": "Sample",
            },
        ]

    return _benchmark_data


def load_leaderboard():
    """Load saved leaderboard results."""
    global _leaderboard

    try:
        with open(LEADERBOARD_PATH, "r") as f:
            _leaderboard = json.load(f)
    except FileNotFoundError:
        _leaderboard = []

    return _leaderboard


def save_leaderboard():
    """Save leaderboard to file."""
    global _leaderboard

    Path(LEADERBOARD_PATH).parent.mkdir(parents=True, exist_ok=True)
    with open(LEADERBOARD_PATH, "w") as f:
        json.dump(_leaderboard, f, indent=2)


def add_to_leaderboard(results: EvaluationResults):
    """Add evaluation results to leaderboard."""
    global _leaderboard

    entry = results.to_dict()
    entry["timestamp"] = datetime.now().isoformat()

    # Remove existing entry for same model
    _leaderboard = [e for e in _leaderboard if e["model_id"] != results.model_id]
    _leaderboard.append(entry)

    # Sort by MRR descending
    _leaderboard.sort(key=lambda x: x["mrr"], reverse=True)

    save_leaderboard()


def format_leaderboard_df():
    """Format leaderboard as pandas DataFrame for display."""
    load_leaderboard()

    if not _leaderboard:
        return pd.DataFrame(columns=[
            "#", "Model", "MRR", "R@1", "R@5", "R@10",
            "Bitext", "TrueSim", "RandSim", "N"
        ])

    rows = []
    for i, entry in enumerate(_leaderboard, 1):
        rows.append({
            "#": i,
            "Model": entry.get("model_name", entry["model_id"]),
            "MRR": f"{entry['mrr']:.3f}",
            "R@1": f"{entry['recall_at_1']:.1%}",
            "R@5": f"{entry['recall_at_5']:.1%}",
            "R@10": f"{entry['recall_at_10']:.1%}",
            "Bitext": f"{entry['bitext_accuracy']:.1%}",
            "TrueSim": f"{entry['avg_true_pair_similarity']:.3f}",
            "RandSim": f"{entry['avg_random_pair_similarity']:.3f}",
            "N": entry["num_pairs"],
        })

    return pd.DataFrame(rows)


def run_evaluation(
    model_choice: str,
    custom_model_id: str,
    api_key: str,
    max_pairs: int,
):
    """
    Run evaluation for the selected model (generator for streaming status updates).

    Args:
        model_choice: Selected curated model or "custom"
        custom_model_id: Custom model ID if selected
        api_key: API key for API-based models
        max_pairs: Maximum pairs to evaluate

    Yields:
        Tuples of (status, results, leaderboard)
    """
    # Helper to yield status updates
    def status_update(msg):
        return (msg, gr.update(), gr.update())

    # Determine which model to use
    if model_choice == "custom":
        model_id = custom_model_id.strip()
        is_valid, error = validate_model_id(model_id)
        if not is_valid:
            yield (
                f"❌ {error}",
                f"❌ Invalid model ID: {error}",
                format_leaderboard_df(),
            )
            return
    else:
        model_id = model_choice

    # Check if API key is required but not provided
    if requires_api_key(model_id):
        api_key = api_key.strip() if api_key else ""
        env_var = get_api_key_env_var(model_id)
        key_type = get_api_key_type(model_id)

        # Skip API key check for models that support Application Default Credentials
        if not api_key and not os.environ.get(env_var) and not api_key_optional(model_id):
            yield (
                "❌ API key required",
                f"❌ API key required for {model_id}. Please enter your {key_type.upper()} API key or set the {env_var} environment variable.",
                format_leaderboard_df(),
            )
            return

    yield status_update("⏳ Loading benchmark data...")
    benchmark = load_benchmark()

    if max_pairs and max_pairs < len(benchmark):
        benchmark = benchmark[:max_pairs]

    yield status_update(f"⏳ Loading model: {model_id}...")

    try:
        # Pass API key for API-based models
        model = load_model(model_id, api_key=api_key if api_key else None)
    except Exception as e:
        yield (
            "❌ Model load failed",
            f"❌ Failed to load model: {str(e)}",
            format_leaderboard_df(),
        )
        return

    # Stream progress updates during evaluation
    try:
        results = None
        for item in evaluate_model_streaming(model, benchmark, batch_size=32):
            if isinstance(item, str):
                # Progress update
                yield status_update(item)
            else:
                # Final results
                results = item
    except Exception as e:
        yield (
            "❌ Evaluation failed",
            f"❌ Evaluation failed: {str(e)}",
            format_leaderboard_df(),
        )
        return

    yield status_update("⏳ Saving results...")
    add_to_leaderboard(results)

    # Format results summary
    summary = f"""## Results for {results.model_name}

| Metric | Value |
|--------|-------|
| **MRR** | {results.mrr:.4f} |
| **Recall@1** | {results.recall_at_1:.1%} |
| **Recall@5** | {results.recall_at_5:.1%} |
| **Recall@10** | {results.recall_at_10:.1%} |
| **Bitext Accuracy** | {results.bitext_accuracy:.1%} |
| **Avg True Pair Sim** | {results.avg_true_pair_similarity:.4f} |
| **Avg Random Pair Sim** | {results.avg_random_pair_similarity:.4f} |
| **Pairs Evaluated** | {results.num_pairs:,} |
"""

    # Final yield with all results (clear status)
    yield (
        "✅ Complete!",
        summary,
        format_leaderboard_df(),
    )


def create_leaderboard_comparison():
    """Create comparison chart of all models on leaderboard."""
    load_leaderboard()

    if len(_leaderboard) < 2:
        return None

    models = [e.get("model_name", e["model_id"]) for e in _leaderboard]
    mrr = [e["mrr"] for e in _leaderboard]
    r1 = [e["recall_at_1"] for e in _leaderboard]
    r5 = [e["recall_at_5"] for e in _leaderboard]
    r10 = [e["recall_at_10"] for e in _leaderboard]
    bitext = [e["bitext_accuracy"] for e in _leaderboard]

    fig = go.Figure()

    fig.add_trace(go.Bar(name="MRR", x=models, y=mrr, marker_color="#2E86AB"))
    fig.add_trace(go.Bar(name="R@1", x=models, y=r1, marker_color="#A23B72"))
    fig.add_trace(go.Bar(name="R@5", x=models, y=r5, marker_color="#F18F01"))
    fig.add_trace(go.Bar(name="R@10", x=models, y=r10, marker_color="#C73E1D"))
    fig.add_trace(go.Bar(name="Bitext Acc", x=models, y=bitext, marker_color="#6B5B95"))

    fig.update_layout(
        title="Model Comparison",
        yaxis_title="Score",
        yaxis_range=[0, 1],
        barmode="group",
        template="plotly_white",
        height=400,
    )

    return fig


def update_model_inputs_visibility(choice):
    """Show/hide custom model input and API key based on selection."""
    show_custom = (choice == "custom")
    show_api_key = requires_api_key(choice) if choice != "custom" else False

    # Update API key label based on model type
    if show_api_key:
        key_type = get_api_key_type(choice)
        env_var = get_api_key_env_var(choice)
        is_optional = api_key_optional(choice)

        if key_type == "voyage":
            label = "Voyage AI API Key"
            placeholder = f"Enter your Voyage AI API key (or set {env_var} env var)"
        elif key_type == "gemini":
            label = "Gemini API Key (optional if using gcloud)"
            placeholder = f"Leave blank if using gcloud ADC, or enter API key / set {env_var}"
        else:
            label = "OpenAI API Key"
            placeholder = f"Enter your OpenAI API key (or set {env_var} env var)"
        return (
            gr.update(visible=show_custom),
            gr.update(visible=show_api_key, label=label, placeholder=placeholder),
        )

    return (
        gr.update(visible=show_custom),
        gr.update(visible=show_api_key),
    )


# Build the Gradio interface
def create_app():
    """Create and return the Gradio app."""

    # Get all model choices - local models first, then API models
    model_choices = []

    # Local models
    for model_id, info in CURATED_MODELS.items():
        model_choices.append((f"🖥️ {info['name']}", model_id))

    # API models
    for model_id, info in API_MODELS.items():
        model_choices.append((f"🌐 {info['name']}", model_id))

    # Custom option
    model_choices.append(("⚙️ Custom Model (enter ID below)", "custom"))

    # Load initial data
    load_benchmark()
    load_leaderboard()
    benchmark_stats = get_benchmark_stats(_benchmark_data) if _benchmark_data else {}

    with gr.Blocks(
        title="Rabbinic Embedding Benchmark",
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="orange",
            font=gr.themes.GoogleFont("Source Sans Pro"),
        ),
        css="""
        .main-header {
            text-align: center;
            margin-bottom: 1rem;
        }
        .stats-box {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            padding: 1rem;
            border-radius: 8px;
            margin: 0.5rem 0;
        }
        """,
    ) as app:

        gr.Markdown(
            """
            # 📚 Rabbinic Hebrew/Aramaic Embedding Benchmark

            Evaluate embedding models on cross-lingual retrieval between Hebrew/Aramaic
            source texts and their English translations from Sefaria.

            **How it works:** Given a Hebrew/Aramaic text, can the model find its correct
            English translation from a pool of candidates? Models that excel at this task
            produce high-quality embeddings for Rabbinic literature.
            """,
            elem_classes=["main-header"],
        )

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(f"""
                ### 📊 Benchmark Stats
                - **Total Pairs:** {benchmark_stats.get('total_pairs', 'N/A'):,}
                - **Categories:** {len(benchmark_stats.get('categories', {}))}
                - **Avg Hebrew Length:** {benchmark_stats.get('avg_he_length', 0):.0f} chars
                """)

            with gr.Column(scale=1):
                gr.Markdown("""
                ### 📏 Metrics
                - **MRR:** Mean Reciprocal Rank
                - **R@k:** Recall at k (correct in top k)
                - **Bitext Acc:** True vs random pair classification
                """)

        gr.Markdown("---")

        with gr.Tabs():
            with gr.TabItem("🔬 Evaluate Model"):
                with gr.Row():
                    with gr.Column(scale=2):
                        model_dropdown = gr.Dropdown(
                            choices=model_choices,
                            value=model_choices[0][1],
                            label="Select Model",
                            info="Choose a curated model or enter a custom Hugging Face model ID",
                        )

                        custom_model_input = gr.Textbox(
                            label="Custom Model ID",
                            placeholder="e.g., organization/model-name",
                            visible=False,
                        )

                        api_key_input = gr.Textbox(
                            label="API Key",
                            placeholder="Enter your API key (or set appropriate env var)",
                            type="password",
                            visible=False,
                            info="Required for API-based models (OpenAI, Voyage AI). Your key is not stored.",
                        )

                        total_pairs = benchmark_stats.get('total_pairs', 1000)
                        max_pairs_slider = gr.Slider(
                            minimum=100,
                            maximum=total_pairs,
                            value=total_pairs,
                            step=100,
                            label="Max Pairs to Evaluate",
                            info="Use fewer pairs for faster evaluation",
                        )

                    with gr.Column(scale=3):
                        evaluate_btn = gr.Button(
                            "🚀 Run Evaluation",
                            variant="primary",
                            size="lg",
                        )

                        status_text = gr.Markdown("")

                        results_markdown = gr.Markdown("")

            with gr.TabItem("🏆 Leaderboard"):
                leaderboard_table = gr.Dataframe(
                    value=format_leaderboard_df(),
                    label="Model Rankings",
                    interactive=False,
                )

                refresh_btn = gr.Button("🔄 Refresh Leaderboard")

                comparison_plot = gr.Plot(label="Model Comparison")

        gr.Markdown("""
        ---
        ### About

        This benchmark evaluates embedding models for Rabbinic Hebrew and Aramaic texts using
        cross-lingual retrieval.

        All texts and translations sourced from [Sefaria](https://www.sefaria.org).
        """)

        # Event handlers
        model_dropdown.change(
            fn=update_model_inputs_visibility,
            inputs=[model_dropdown],
            outputs=[custom_model_input, api_key_input],
        )

        evaluate_btn.click(
            fn=run_evaluation,
            inputs=[model_dropdown, custom_model_input, api_key_input, max_pairs_slider],
            outputs=[status_text, results_markdown, leaderboard_table],
            show_progress="hidden",
        )

        refresh_btn.click(
            fn=lambda: (format_leaderboard_df(), create_leaderboard_comparison()),
            outputs=[leaderboard_table, comparison_plot],
        )

    return app


# Main entry point
if __name__ == "__main__":
    app = create_app()
    app.launch()
benchmark-stats.txt
ADDED
@@ -0,0 +1,14 @@
Total pairs: 3,721
Categories:
- Halacha: 160
- Hasidic/Kabbalistic: 304
- Jerusalem Talmud: 520
- Midrash Rabbah: 400
- Mishnah: 789
- Mussar/Ethics: 108
- Philosophy: 240
- Talmud: 480
- Tanakh Commentary: 680
- Targum: 40
Average Hebrew text length: 650 chars
Average English text length: 995 chars
benchmark_data/benchmark.json
ADDED
The diff for this file is too large to render. See raw diff.
build_benchmark.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Script to build the benchmark dataset from Sefaria API.
|
| 4 |
+
|
| 5 |
+
Run this script to fetch and cache parallel Hebrew/Aramaic-English text pairs
|
| 6 |
+
from Sefaria for use in the embedding evaluation benchmark.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python build_benchmark.py [--max-per-text N] [--total N] [--output PATH]
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import sys
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import requests
|
| 17 |
+
|
| 18 |
+
from data_loader import (
|
| 19 |
+
build_benchmark_dataset,
|
| 20 |
+
get_benchmark_stats,
|
| 21 |
+
get_index_from_sefaria,
|
| 22 |
+
set_sefaria_host,
|
| 23 |
+
get_sefaria_host,
|
| 24 |
+
BENCHMARK_TEXTS,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_name_suggestions(title: str, host: str, limit: int = 5) -> list[str]:
|
| 29 |
+
"""Get name suggestions from the Sefaria name API."""
|
| 30 |
+
try:
|
| 31 |
+
url = f"{host}/api/name/{title}"
|
| 32 |
+
response = requests.get(url, params={"limit": limit, "type": "ref"}, timeout=10)
|
| 33 |
+
if response.status_code == 200:
|
| 34 |
+
data = response.json()
|
| 35 |
+
# Return completions that are refs (book titles)
|
| 36 |
+
completions = data.get("completions", [])
|
| 37 |
+
return completions[:limit]
|
| 38 |
+
except requests.RequestException:
|
| 39 |
+
pass
|
| 40 |
+
return []
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def main():
|
| 44 |
+
parser = argparse.ArgumentParser(
|
| 45 |
+
description="Build Rabbinic embedding benchmark dataset from Sefaria"
|
| 46 |
+
)
|
| 47 |
+
parser.add_argument(
|
| 48 |
+
"--max-per-text",
|
| 49 |
+
type=int,
|
| 50 |
+
default=40,
|
| 51 |
+
help="Maximum segments per text (default: 40)",
|
| 52 |
+
)
|
| 53 |
+
parser.add_argument(
|
| 54 |
+
"--total",
|
| 55 |
+
type=int,
|
| 56 |
+
default=10000,
|
| 57 |
+
help="Total target segments (default: 10000)",
|
| 58 |
+
)
|
| 59 |
+
parser.add_argument(
|
| 60 |
+
"--output",
|
| 61 |
+
type=str,
|
| 62 |
+
default="benchmark_data/benchmark.json",
|
| 63 |
+
help="Output file path (default: benchmark_data/benchmark.json)",
|
| 64 |
+
)
|
| 65 |
+
parser.add_argument(
|
| 66 |
+
"--dry-run",
|
| 67 |
+
action="store_true",
|
| 68 |
+
help="Show what would be fetched without making API calls",
|
| 69 |
+
)
|
| 70 |
+
parser.add_argument(
|
| 71 |
+
"--host",
|
| 72 |
+
type=str,
|
| 73 |
+
default=None,
|
| 74 |
+
help="Sefaria host URL (default: https://www.sefaria.org, or SEFARIA_HOST env var)",
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--check-titles",
|
| 78 |
+
action="store_true",
|
| 79 |
+
help="Check all text titles against the API to verify they exist",
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
args = parser.parse_args()
|
| 83 |
+
|
| 84 |
+
# Configure Sefaria host if specified
|
| 85 |
+
if args.host:
    set_sefaria_host(args.host)

    if args.check_titles:
        print("="*60)
        print("Checking Text Titles Against API")
        print("="*60)
        host = get_sefaria_host()
        print(f"\nSefaria host: {host}\n")

        valid = []
        invalid = []
        suggestions = {}

        for category_key, category_info in BENCHMARK_TEXTS.items():
            category_name = category_info["category"]
            print(f"\n{category_name}:")

            for text in category_info["texts"]:
                index = get_index_from_sefaria(text)
                if index:
                    print(f"  ✓ {text}")
                    valid.append(text)
                else:
                    # Get suggestions from name API
                    suggested = get_name_suggestions(text, host)
                    suggestions[text] = suggested
                    if suggested:
                        print(f"  ✗ {text} → Did you mean: {suggested[0]}?")
                    else:
                        print(f"  ✗ {text}")
                    invalid.append(text)

        print("\n" + "="*60)
        print("SUMMARY")
        print("="*60)
        print(f"\nValid titles: {len(valid)}")
        print(f"Invalid titles: {len(invalid)}")

        if invalid:
            print("\nInvalid titles that need fixing:")
            for title in invalid:
                suggested = suggestions.get(title, [])
                if suggested:
                    print(f"  - {title}")
                    print(f"    Suggestions: {', '.join(suggested[:3])}")
                else:
                    print(f"  - {title} (no suggestions found)")
        else:
            print("\nAll titles are valid!")

        return

    if args.dry_run:
        print("DRY RUN: Would fetch from these texts:\n")
        print(f"Sefaria host: {get_sefaria_host()}")
        total_texts = 0
        for category_key, category_info in BENCHMARK_TEXTS.items():
            print(f"\n{category_info['category']} ({category_info['language']}):")
            for text in category_info["texts"]:
                print(f"  - {text}")
                total_texts += 1
        print(f"\nTotal texts: {total_texts}")
        print(f"Target segments per text: {args.max_per_text}")
        print(f"Total target segments: {args.total}")
        return

    print("="*60)
    print("Building Rabbinic Embedding Benchmark Dataset")
    print("="*60)
    print("\nSettings:")
    print(f"  Sefaria host: {get_sefaria_host()}")
    print(f"  Max segments per text: {args.max_per_text}")
    print(f"  Total target: {args.total}")
    print(f"  Output: {args.output}")
    print()

    # Ensure output directory exists
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)

    # Build the dataset
    pairs = build_benchmark_dataset(
        output_path=args.output,
        segments_per_text=args.max_per_text,
        total_target=args.total,
    )

    # Print final statistics
    stats = get_benchmark_stats(pairs)

    print("\n" + "="*60)
    print("BENCHMARK COMPLETE")
    print("="*60)
    print("\nFinal Statistics:")
    print(f"  Total pairs: {stats['total_pairs']:,}")
    print("  Categories:")
    for cat, count in sorted(stats["categories"].items()):
        print(f"    - {cat}: {count:,}")
    print(f"  Average Hebrew text length: {stats['avg_he_length']:.0f} chars")
    print(f"  Average English text length: {stats['avg_en_length']:.0f} chars")
    print(f"\nSaved to: {args.output}")


if __name__ == "__main__":
    main()
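The final-statistics printout above consumes the dict returned by `get_benchmark_stats` (defined earlier in this file, outside this hunk). A minimal sketch of the shape it relies on, computed by hand over a toy pair list — the `pairs` data and the inline computation here are illustrative stand-ins, not the real implementation:

```python
# Toy stand-in for the stats dict consumed by the summary printout above.
pairs = [
    {"category": "Mishnah", "he": "אבג", "en": "abcdef"},
    {"category": "Mishnah", "he": "אבגדה", "en": "abcd"},
    {"category": "Talmud", "he": "אב", "en": "ab"},
]
categories = {}
for p in pairs:
    categories[p["category"]] = categories.get(p["category"], 0) + 1
stats = {
    "total_pairs": len(pairs),
    "categories": categories,
    "avg_he_length": sum(len(p["he"]) for p in pairs) / len(pairs),
    "avg_en_length": sum(len(p["en"]) for p in pairs) / len(pairs),
}
print(stats["total_pairs"], stats["categories"])
# → 3 {'Mishnah': 2, 'Talmud': 1}
```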
check_token_limits.py
ADDED
@@ -0,0 +1,306 @@
"""
Check token limits for benchmark data entries.

This script scans the benchmark dataset and flags entries that exceed
the 8192 token limit used by OpenAI embedding models (text-embedding-ada-002,
text-embedding-3-small, text-embedding-3-large).

Uses tiktoken with the cl100k_base encoding, which is the tokenizer used
by OpenAI's embedding models.
"""

import json
import argparse
from pathlib import Path
from dataclasses import dataclass

import tiktoken


# OpenAI embedding models use cl100k_base encoding
ENCODING_NAME = "cl100k_base"
MAX_TOKENS = 8192


@dataclass
class TokenOverage:
    """Represents an entry that exceeds the token limit."""
    ref: str
    category: str
    field: str  # 'he', 'en', or 'combined'
    token_count: int
    char_count: int
    text_preview: str  # First N characters of the text


def count_tokens(text: str, encoding: tiktoken.Encoding) -> int:
    """Count the number of tokens in a text string."""
    return len(encoding.encode(text))


def check_entry(
    entry: dict,
    encoding: tiktoken.Encoding,
    max_tokens: int = MAX_TOKENS,
    preview_length: int = 100
) -> list[TokenOverage]:
    """
    Check a single entry for token limit violations.

    Args:
        entry: Dictionary with 'ref', 'he', 'en', 'category' keys
        encoding: tiktoken encoding to use
        max_tokens: Maximum allowed tokens
        preview_length: Number of characters to include in preview

    Returns:
        List of TokenOverage objects for any fields exceeding the limit
    """
    overages = []

    ref = entry.get("ref", "unknown")
    category = entry.get("category", "unknown")

    for field in ["he", "en"]:
        text = entry.get(field, "")
        if not text:
            continue

        token_count = count_tokens(text, encoding)

        if token_count > max_tokens:
            preview = text[:preview_length] + "..." if len(text) > preview_length else text
            overages.append(TokenOverage(
                ref=ref,
                category=category,
                field=field,
                token_count=token_count,
                char_count=len(text),
                text_preview=preview
            ))

    return overages


def check_benchmark_data(
    data_path: str,
    max_tokens: int = MAX_TOKENS,
    verbose: bool = False
) -> tuple[list[TokenOverage], dict]:
    """
    Check all entries in the benchmark dataset for token limit violations.

    Args:
        data_path: Path to the benchmark JSON file
        max_tokens: Maximum allowed tokens (default: 8192)
        verbose: Print progress information

    Returns:
        Tuple of (list of overages, statistics dict)
    """
    # Load the encoding
    if verbose:
        print(f"Loading tokenizer: {ENCODING_NAME}")
    encoding = tiktoken.get_encoding(ENCODING_NAME)

    # Load the data
    if verbose:
        print(f"Loading data from: {data_path}")
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if verbose:
        print(f"Checking {len(data)} entries for token limit ({max_tokens} tokens)...")

    # Check all entries
    all_overages = []
    token_counts_he = []
    token_counts_en = []

    for i, entry in enumerate(data):
        if verbose and (i + 1) % 1000 == 0:
            print(f"  Processed {i + 1}/{len(data)} entries...")

        # Count tokens for statistics
        he_text = entry.get("he", "")
        en_text = entry.get("en", "")

        if he_text:
            token_counts_he.append(count_tokens(he_text, encoding))
        if en_text:
            token_counts_en.append(count_tokens(en_text, encoding))

        # Check for overages
        overages = check_entry(entry, encoding, max_tokens)
        all_overages.extend(overages)

    # Compute statistics
    stats = {
        "total_entries": len(data),
        "entries_with_overages": len(set(o.ref for o in all_overages)),
        "total_overages": len(all_overages),
        "he_overages": len([o for o in all_overages if o.field == "he"]),
        "en_overages": len([o for o in all_overages if o.field == "en"]),
        "max_tokens_checked": max_tokens,
    }

    if token_counts_he:
        stats["he_token_stats"] = {
            "min": min(token_counts_he),
            "max": max(token_counts_he),
            "avg": sum(token_counts_he) / len(token_counts_he),
            "total_entries": len(token_counts_he),
        }

    if token_counts_en:
        stats["en_token_stats"] = {
            "min": min(token_counts_en),
            "max": max(token_counts_en),
            "avg": sum(token_counts_en) / len(token_counts_en),
            "total_entries": len(token_counts_en),
        }

    return all_overages, stats


def print_report(overages: list[TokenOverage], stats: dict) -> None:
    """Print a formatted report of token limit violations."""
    print("\n" + "=" * 70)
    print("TOKEN LIMIT CHECK REPORT")
    print("=" * 70)

    print("\nDataset Summary:")
    print(f"  Total entries checked: {stats['total_entries']:,}")
    print(f"  Token limit: {stats['max_tokens_checked']:,}")

    if "he_token_stats" in stats:
        he_stats = stats["he_token_stats"]
        print("\nHebrew/Aramaic Token Statistics:")
        print(f"  Min tokens: {he_stats['min']:,}")
        print(f"  Max tokens: {he_stats['max']:,}")
        print(f"  Avg tokens: {he_stats['avg']:.1f}")

    if "en_token_stats" in stats:
        en_stats = stats["en_token_stats"]
        print("\nEnglish Token Statistics:")
        print(f"  Min tokens: {en_stats['min']:,}")
        print(f"  Max tokens: {en_stats['max']:,}")
        print(f"  Avg tokens: {en_stats['avg']:.1f}")

    print("\nOverage Summary:")
    print(f"  Entries exceeding limit: {stats['entries_with_overages']:,}")
    print(f"  Total field overages: {stats['total_overages']:,}")
    print(f"    - Hebrew/Aramaic fields: {stats['he_overages']:,}")
    print(f"    - English fields: {stats['en_overages']:,}")

    if overages:
        print("\n" + "-" * 70)
        print("FLAGGED ENTRIES (exceeding token limit):")
        print("-" * 70)

        # Group by category
        by_category = {}
        for overage in overages:
            if overage.category not in by_category:
                by_category[overage.category] = []
            by_category[overage.category].append(overage)

        for category, category_overages in sorted(by_category.items()):
            print(f"\n[{category}] - {len(category_overages)} overage(s)")
            for overage in category_overages:
                print(f"\n  Reference: {overage.ref}")
                print(f"  Field: {overage.field}")
                print(f"  Token count: {overage.token_count:,} (limit: {stats['max_tokens_checked']:,})")
                print(f"  Character count: {overage.char_count:,}")
                print(f"  Preview: {overage.text_preview}")
    else:
        print("\n✓ No entries exceed the token limit!")

    print("\n" + "=" * 70)


def save_report(
    overages: list[TokenOverage],
    stats: dict,
    output_path: str
) -> None:
    """Save the report to a JSON file."""
    report = {
        "stats": stats,
        "overages": [
            {
                "ref": o.ref,
                "category": o.category,
                "field": o.field,
                "token_count": o.token_count,
                "char_count": o.char_count,
                "text_preview": o.text_preview,
            }
            for o in overages
        ]
    }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(report, f, ensure_ascii=False, indent=2)

    print(f"\nReport saved to: {output_path}")


def main():
    parser = argparse.ArgumentParser(
        description="Check benchmark data for entries exceeding OpenAI embedding token limits."
    )
    parser.add_argument(
        "--data",
        "-d",
        type=str,
        default="benchmark_data/benchmark.json",
        help="Path to the benchmark JSON file (default: benchmark_data/benchmark.json)"
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=MAX_TOKENS,
        help=f"Maximum allowed tokens (default: {MAX_TOKENS})"
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        help="Path to save JSON report (optional)"
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Print progress information"
    )

    args = parser.parse_args()

    # Check if data file exists
    if not Path(args.data).exists():
        print(f"Error: Data file not found: {args.data}")
        return 1

    # Run the check
    overages, stats = check_benchmark_data(
        args.data,
        max_tokens=args.max_tokens,
        verbose=args.verbose
    )

    # Print report
    print_report(overages, stats)

    # Save report if requested
    if args.output:
        save_report(overages, stats, args.output)

    # Return exit code based on whether overages were found
    return 1 if overages else 0


if __name__ == "__main__":
    exit(main())
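The per-field flagging logic in `check_entry` can be exercised without tiktoken by swapping in a stand-in tokenizer — a whitespace split here, purely for illustration (the real script uses the cl100k_base encoding and an 8192-token limit); `find_overages` is a simplified sketch, not part of the module:

```python
def find_overages(entries, tokenize, max_tokens):
    # Same shape as check_entry: flag any 'he'/'en' field over the token limit.
    flagged = []
    for entry in entries:
        for field in ("he", "en"):
            text = entry.get(field, "")
            if text and len(tokenize(text)) > max_tokens:
                flagged.append((entry.get("ref", "unknown"), field, len(tokenize(text))))
    return flagged

entries = [
    {"ref": "Berakhot 2a:1", "he": "a b c", "en": "one two three four five"},
    {"ref": "Berakhot 2a:2", "he": "a b", "en": "one two"},
]
# str.split stands in for tiktoken's encode: 5 "tokens" exceeds the limit of 4.
flagged = find_overages(entries, str.split, max_tokens=4)
print(flagged)
# → [('Berakhot 2a:1', 'en', 5)]
```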
data_loader.py
ADDED
@@ -0,0 +1,828 @@
"""
Data loader for Rabbinic Hebrew/Aramaic benchmark texts from Sefaria API.

Fetches parallel Hebrew/Aramaic + English text pairs across diverse categories.
"""

import json
import os
import re
import time
from pathlib import Path
from typing import Optional

import requests
import tiktoken

# Token limit for OpenAI embedding models (text-embedding-ada-002, text-embedding-3-*)
# Using cl100k_base encoding
MAX_EMBEDDING_TOKENS = 8192
_tokenizer = None


def get_tokenizer() -> tiktoken.Encoding:
    """Get or create the tiktoken encoder (cached for performance)."""
    global _tokenizer
    if _tokenizer is None:
        _tokenizer = tiktoken.get_encoding("cl100k_base")
    return _tokenizer


def count_tokens(text: str) -> int:
    """Count the number of tokens in a text string using OpenAI's tokenizer."""
    return len(get_tokenizer().encode(text))

# Sefaria host - configurable via environment variable
# Default is the public Sefaria API
DEFAULT_SEFARIA_HOST = "https://www.sefaria.org"
SEFARIA_HOST = os.environ.get("SEFARIA_HOST", DEFAULT_SEFARIA_HOST)


def set_sefaria_host(host: str) -> None:
    """Set the Sefaria host URL (e.g., 'http://localhost:8000')."""
    global SEFARIA_HOST
    # Remove trailing slash if present
    SEFARIA_HOST = host.rstrip("/")


def get_sefaria_host() -> str:
    """Get the current Sefaria host URL."""
    return SEFARIA_HOST

# Text categories with confirmed English translations
BENCHMARK_TEXTS = {
    "talmud_bavli": {
        "category": "Talmud",
        "language": "Aramaic/Hebrew",
        "texts": [
            "Berakhot",
            "Pesachim",
            "Yoma",
            "Megillah",
            "Chagigah",
            "Ketubot",
            "Gittin",
            "Bava Metzia",
            "Sanhedrin",
            "Avodah Zarah",
            "Chullin",
            "Niddah",
        ],
    },
    "talmud_yerushalmi": {
        "category": "Jerusalem Talmud",
        "language": "Aramaic/Hebrew",
        "texts": [
            "Jerusalem Talmud Berakhot",
            "Jerusalem Talmud Kilayim",
            "Jerusalem Talmud Terumot",
            "Jerusalem Talmud Shabbat",
            "Jerusalem Talmud Shekalim",
            "Jerusalem Talmud Sukkah",
            "Jerusalem Talmud Sotah",
            "Jerusalem Talmud Nedarim",
            "Jerusalem Talmud Kiddushin",
            "Jerusalem Talmud Bava Kamma",
            "Jerusalem Talmud Sanhedrin",
            "Jerusalem Talmud Avodah Zarah",
            "Jerusalem Talmud Niddah",
        ],
    },
    "mishnah": {
        "category": "Mishnah",
        "language": "Rabbinic Hebrew",
        "texts": [
            "Mishnah Berakhot",
            "Mishnah Peah",
            "Mishnah Kilayim",
            "Mishnah Shabbat",
            "Mishnah Pesachim",
            "Mishnah Sukkah",
            "Mishnah Taanit",
            "Mishnah Chagigah",
            "Mishnah Yevamot",
            "Mishnah Sotah",
            "Mishnah Kiddushin",
            "Mishnah Bava Kamma",
            "Mishnah Sanhedrin",
            "Mishnah Eduyot",
            "Mishnah Avot",
            "Mishnah Zevachim",
            "Mishnah Chullin",
            "Mishnah Tamid",
            "Mishnah Kelim",
            "Mishnah Parah",
            "Mishnah Niddah",
        ],
    },
    "midrash_rabbah": {
        "category": "Midrash Rabbah",
        "language": "Hebrew/Aramaic",
        "texts": [
            "Bereishit Rabbah",
            "Shemot Rabbah",
            "Vayikra Rabbah",
            "Bamidbar Rabbah",
            "Devarim Rabbah",
            "Shir HaShirim Rabbah",
            "Ruth Rabbah",
            "Eichah Rabbah",
            "Kohelet Rabbah",
            "Esther Rabbah",
        ],
    },
    "tanakh_commentary": {
        "category": "Tanakh Commentary",
        "language": "Hebrew",
        "texts": [
            "Rashi on Genesis",
            "Rashi on Exodus",
            "Rashi on Leviticus",
            "Rashi on Numbers",
            "Rashi on Deuteronomy",
            "Ramban on Genesis",
            "Ramban on Exodus",
            "Ramban on Leviticus",
            "Ramban on Numbers",
            "Ramban on Deuteronomy",
            "Radak on Genesis",
            "Akeidat Yitzchak",
            "Rabbeinu Behaye, Bereshit",
            "Rabbeinu Behaye, Shemot",
            "Rabbeinu Behaye, Vayikra",
            "Rabbeinu Behaye, Bamidbar",
            "Rabbeinu Behaye, Devarim",
        ],
    },
    "hasidic_kabbalistic": {
        "category": "Hasidic/Kabbalistic",
        "language": "Hebrew",
        "texts": [
            "Likutei Moharan",
            "Tomer Devorah",
            "Or Neerav, PART I",
            "Or Neerav, PART II",
            "Or Neerav, PART III",
            "Shekel HaKodesh, On Abstinence",
            "Shekel HaKodesh, On Wisdom",
            "Kalach Pitchei Chokhmah",
        ],
    },
    "halacha": {
        "category": "Halacha",
        "language": "Hebrew",
        "texts": [
            "Sefer HaChinukh",
            "Shev Shmateta, Introduction",
            "Mishneh Torah, Human Dispositions",
            "Sefer Yesodei HaTorah",
        ],
    },
    "philosophy": {
        "category": "Philosophy",
        "language": "Hebrew",
        "texts": [
            "Sefer HaIkkarim, Maamar 1",
            "Sefer HaIkkarim, Maamar 2",
            "Sefer HaIkkarim, Maamar 3",
            "Guide for the Perplexed, Part 1",
            "Guide for the Perplexed, Part 2",
            "Guide for the Perplexed, Part 3",
        ],
    },
    "targum": {
        "category": "Targum",
        "language": "Aramaic",
        "texts": [
            "Aramaic Targum to Song of Songs",
        ],
    },
    "mussar": {
        "category": "Mussar/Ethics",
        "language": "Hebrew",
        "texts": [
            "Iggeret HaRamban",
            "Shulchan Shel Arba",
            "Chafetz Chaim",
            "Yesod HaYirah, On Endurance",
            "Yesod HaYirah, On Humility",
            "Kav HaYashar",
        ],
    },
}
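Downstream loaders iterate this registry category by category; a two-entry stand-in (`registry` here is illustrative, not the real table) shows the shape they expect:

```python
# Miniature stand-in for BENCHMARK_TEXTS, just to show the expected shape.
registry = {
    "mishnah": {
        "category": "Mishnah",
        "language": "Rabbinic Hebrew",
        "texts": ["Mishnah Berakhot", "Mishnah Avot"],
    },
    "targum": {
        "category": "Targum",
        "language": "Aramaic",
        "texts": ["Aramaic Targum to Song of Songs"],
    },
}
# Count texts per human-readable category name, as the dry-run report does.
totals = {info["category"]: len(info["texts"]) for info in registry.values()}
print(totals)
# → {'Mishnah': 2, 'Targum': 1}
```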
def strip_html(text: str) -> str:
    """
    Remove HTML tags from text.

    Some tags are dropped completely with their content:
    - <sup class="footnote-marker">...</sup>
    - <i class="footnote"...>...</i>

    Other tags are stripped but their inner content is preserved.
    """
    # First, remove footnote markers (simple, no nesting issues)
    clean = re.sub(r'<sup[^>]*class="footnote-marker"[^>]*>.*?</sup>', '', text, flags=re.DOTALL)

    # Remove footnotes with nested <i> tags - need to handle nesting
    # Strategy: find footnote start, then count <i> and </i> to find matching close
    clean = _remove_footnote_tags(clean)

    # Then strip remaining HTML tags (keeping their content)
    clean = re.sub(r"<[^>]+>", "", clean)

    # Clean up extra whitespace
    clean = re.sub(r"\s+", " ", clean).strip()
    return clean


def _remove_footnote_tags(text: str) -> str:
    """Remove <i class="footnote"...>...</i> tags, handling nested <i> tags."""
    result = []
    i = 0

    while i < len(text):
        # Look for footnote opening tag
        match = re.match(r'<i[^>]*class="footnote"[^>]*>', text[i:], flags=re.IGNORECASE)
        if match:
            # Found a footnote, now find the matching </i>
            start = i + match.end()
            depth = 1
            j = start

            while j < len(text) and depth > 0:
                if text[j:j+3].lower() == '<i ' or text[j:j+3].lower() == '<i>':
                    depth += 1
                    j += 1
                elif text[j:j+4].lower() == '</i>':
                    depth -= 1
                    if depth == 0:
                        # Skip past the closing </i>
                        j += 4
                        break
                    j += 1
                else:
                    j += 1

            # Skip the entire footnote (from i to j)
            i = j
        else:
            result.append(text[i])
            i += 1

    return ''.join(result)
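The depth-counting idea behind footnote removal can be shown in a compact, self-contained sketch (the function name `drop_footnotes` and its exact control flow are illustrative, not part of this module):

```python
import re

def drop_footnotes(text: str) -> str:
    # Depth-counting removal of <i class="footnote">...</i> spans, tolerating
    # nested <i> tags, then strip remaining tags and collapse whitespace --
    # the same idea as strip_html / _remove_footnote_tags above.
    out, i = [], 0
    while i < len(text):
        m = re.match(r'<i[^>]*class="footnote"[^>]*>', text[i:], flags=re.IGNORECASE)
        if m:
            j, depth = i + m.end(), 1
            while j < len(text) and depth > 0:
                if text[j:j + 3].lower() in ("<i ", "<i>"):
                    depth += 1
                    j += 1
                elif text[j:j + 4].lower() == "</i>":
                    depth -= 1
                    j += 4
                else:
                    j += 1
            i = j  # skip the entire footnote span
        else:
            out.append(text[i])
            i += 1
    clean = re.sub(r"<[^>]+>", "", "".join(out))
    return re.sub(r"\s+", " ", clean).strip()

html = 'Text <b>bold</b> <i class="footnote">note with <i>nested</i> italics</i> tail.'
print(drop_footnotes(html))
# → Text bold tail.
```

A naive non-greedy regex would stop at the first `</i>` and leave the nested footnote tail behind; the explicit depth counter is what makes nested `<i>` tags safe.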
def extract_bold_only(text: str) -> str:
    """
    Extract only content within <b>...</b> tags, for Talmud Bavli.

    The Steinsaltz English has bold for actual translation and non-bold for
    elucidation. We only want the translation.

    Example:
        "<b>The Rabbis say:</b> The time for... is <b>until midnight.</b>"
        -> "The Rabbis say: until midnight."
    """
    # Find all content within <b>...</b> tags
    bold_parts = re.findall(r'<b>(.*?)</b>', text, flags=re.DOTALL)

    if not bold_parts:
        # No bold tags found, fall back to regular strip
        return strip_html(text)

    # Strip any nested HTML from each bold part and join with spaces
    cleaned_parts = [strip_html(part) for part in bold_parts]

    # Join parts, ensuring proper spacing
    result = ' '.join(cleaned_parts)

    # Clean up extra whitespace
    result = re.sub(r"\s+", " ", result).strip()
    return result
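The bold-only extraction can be restated as a self-contained sketch (the `bold_only` helper here is illustrative and omits the module's footnote handling):

```python
import re

def bold_only(html: str) -> str:
    # Keep only the <b>...</b> spans -- the Steinsaltz convention puts the
    # literal translation in bold and editorial elucidation in plain text.
    parts = re.findall(r"<b>(.*?)</b>", html, flags=re.DOTALL)
    if not parts:
        # No bold spans: fall back to stripping all tags.
        return re.sub(r"\s+", " ", re.sub(r"<[^>]+>", "", html)).strip()
    joined = " ".join(re.sub(r"<[^>]+>", "", p) for p in parts)
    return re.sub(r"\s+", " ", joined).strip()

s = '<b>The Rabbis say:</b> the time for reciting is <b>until midnight.</b>'
print(bold_only(s))
# → The Rabbis say: until midnight.
```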
def get_text_from_sefaria(ref: str, retries: int = 3) -> Optional[dict]:
    """
    Fetch a text from Sefaria API.

    Args:
        ref: Sefaria reference string (e.g., "Berakhot.2a")
        retries: Number of retry attempts

    Returns:
        Dict with 'he' (Hebrew/Aramaic) and 'en' (English) texts, or None if failed/error
    """
    url = f"{SEFARIA_HOST}/api/texts/{ref}"
    params = {"context": 0}

    for attempt in range(retries):
        try:
            response = requests.get(url, params=params, timeout=30)
            if response.status_code == 200:
                data = response.json()
                # Check if response contains an error
                if "error" in data:
                    return None
                return data
            elif response.status_code == 429:
                # Rate limited, wait and retry
                time.sleep(2 ** attempt)
            else:
                return None
        except requests.RequestException:
            if attempt < retries - 1:
                time.sleep(1)
                continue
    return None
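The 429 handling above sleeps `2 ** attempt` seconds between attempts. That retry shape can be isolated and tested without the network by injecting the fetch and sleep functions — `fetch_with_backoff` and the simulated responses below are an illustrative sketch, not part of this module:

```python
def fetch_with_backoff(fetch, retries=3, sleep=None):
    # Same retry shape as get_text_from_sefaria: retry on a 429-style signal
    # with exponential backoff (2**attempt seconds); give up after `retries`.
    import time
    sleep = sleep or time.sleep
    for attempt in range(retries):
        status, payload = fetch()
        if status == 200:
            return payload
        if status == 429:
            sleep(2 ** attempt)
        else:
            return None
    return None

# Simulated server: rate-limited twice, then succeeds.
responses = iter([(429, None), (429, None), (200, {"he": "..."})])
waits = []  # record the backoff intervals instead of actually sleeping
result = fetch_with_backoff(lambda: next(responses), sleep=waits.append)
print(result, waits)
# → {'he': '...'} [1, 2]
```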
def get_index_from_sefaria(title: str) -> Optional[dict]:
    """
    Get index/structure information for a text.

    Args:
        title: The title of the text

    Returns:
        Index data or None if failed or text not found
    """
    url = f"{SEFARIA_HOST}/api/index/{title}"
    try:
        response = requests.get(url, timeout=30)
        if response.status_code == 200:
            data = response.json()
            # Check if response contains an error
            if "error" in data:
                return None
            return data
    except requests.RequestException:
        pass
    return None
| 365 |
+
def extract_parallel_segments(data: dict, ref: str, category: str = "") -> list[dict]:
    """
    Extract parallel Hebrew/English segments from an API response.

    Args:
        data: API response data
        ref: The reference string
        category: Category name (used for special handling, e.g., "Talmud")

    Returns:
        List of dicts with 'ref', 'he', 'en' keys
    """
    segments = []

    he_text = data.get("he", [])
    en_text = data.get("text", [])

    # Handle nested arrays (common in Talmud)
    if he_text and isinstance(he_text, list):
        # Flatten one level if nested
        if isinstance(he_text[0], list):
            he_flat = []
            en_flat = []
            for he_seg, en_seg in zip(he_text, en_text):
                if isinstance(he_seg, list):
                    he_flat.extend(he_seg)
                    en_flat.extend(en_seg if isinstance(en_seg, list) else [en_seg])
                else:
                    he_flat.append(he_seg)
                    en_flat.append(en_seg)
            he_text = he_flat
            en_text = en_flat

    # Handle single-string responses
    if isinstance(he_text, str):
        he_text = [he_text]
    if isinstance(en_text, str):
        en_text = [en_text]

    # For Talmud Bavli, extract only bold text (the literal translation, not elucidation)
    is_bavli = category == "Talmud"

    # Pair up segments
    for i, (he, en) in enumerate(zip(he_text, en_text)):
        if he and en:
            he_clean = strip_html(str(he))
            # Use bold-only extraction for Bavli English
            if is_bavli:
                en_clean = extract_bold_only(str(en))
            else:
                en_clean = strip_html(str(en))

            # Skip empty or very short segments
            if len(he_clean) > 10 and len(en_clean) > 10:
                # Check token limits for OpenAI embedding models
                he_tokens = count_tokens(he_clean)
                en_tokens = count_tokens(en_clean)

                if he_tokens > MAX_EMBEDDING_TOKENS:
                    print(f"  Skipping {ref}:{i+1} - Hebrew text exceeds token limit ({he_tokens} > {MAX_EMBEDDING_TOKENS})")
                    continue
                if en_tokens > MAX_EMBEDDING_TOKENS:
                    print(f"  Skipping {ref}:{i+1} - English text exceeds token limit ({en_tokens} > {MAX_EMBEDDING_TOKENS})")
                    continue

                segments.append({
                    "ref": f"{ref}:{i+1}" if ":" not in ref else ref,
                    "he": he_clean,
                    "en": en_clean,
                })

    return segments


def fetch_text_pairs(
    text_title: str,
    category: str,
    max_segments: int = 500,
    delay: float = 0.5,
) -> list[dict]:
    """
    Fetch parallel text pairs for a given text.

    Args:
        text_title: Title of the text to fetch
        category: Category name for metadata
        max_segments: Maximum segments to fetch per text
        delay: Delay between API calls (rate limiting)

    Returns:
        List of segment dicts with ref, he, en, category
    """
    pairs = []

    # Get the text index to understand its structure
    index = get_index_from_sefaria(text_title)
    if not index:
        print(f"  Could not get index for {text_title}")
        return pairs

    # Determine which refs to fetch based on the text's structure
    schema = index.get("schema", {})

    # For simple texts, just fetch the whole thing
    if schema.get("nodeType") == "JaggedArrayNode":
        depth = schema.get("depth", 2)
        address_types = schema.get("addressTypes", [])

        # Check if this text uses Talmud daf notation (2a, 2b, etc.)
        uses_talmud_daf = address_types and address_types[0] == "Talmud"

        if uses_talmud_daf:
            # Talmud-style structure with daf notation (e.g., Berakhot.2a).
            # Start from daf 3 for the Jerusalem Talmud to avoid overlap with the Bavli.
            start_daf = 3 if category == "Jerusalem Talmud" else 2
            # Fetch daf by daf
            done = False
            for daf_num in range(start_daf, 200):
                if len(pairs) >= max_segments or done:
                    break

                for side in ["a", "b"]:
                    if len(pairs) >= max_segments:
                        break

                    ref = f"{text_title}.{daf_num}{side}"
                    data = get_text_from_sefaria(ref)

                    # None means an API error (the daf doesn't exist)
                    if data is None:
                        if side == "a":
                            done = True  # Daf doesn't exist; we're done with this tractate
                        break

                    if not data.get("he"):
                        continue  # Empty side, try the next one

                    segments = extract_parallel_segments(data, ref, category)
                    for seg in segments:
                        seg["category"] = category
                    pairs.extend(segments)

                    time.sleep(delay)

        elif depth == 1:
            # Single-level structure (e.g., Iggeret HaRamban - just paragraphs).
            # Fetch the whole text at once.
            data = get_text_from_sefaria(text_title)
            if data and data.get("he"):
                segments = extract_parallel_segments(data, text_title, category)
                for seg in segments:
                    seg["category"] = category
                pairs.extend(segments)

        elif depth == 2:
            # Two-level structure (e.g., Mishnah chapter:verse).
            # Start from chapter 2 for Mishnah to avoid overlap with the Talmud.
            start_chapter = 2 if category == "Mishnah" else 1
            consecutive_empty = 0
            # Fetch chapter by chapter
            for chapter in range(start_chapter, 200):  # Reasonable upper bound
                if len(pairs) >= max_segments:
                    break

                ref = f"{text_title}.{chapter}"
                data = get_text_from_sefaria(ref)

                # None means an API error (the ref doesn't exist)
                if data is None:
                    break

                # An empty array means the chapter exists but has no content
                if not data.get("he"):
                    consecutive_empty += 1
                    if consecutive_empty >= 5:
                        break  # Probably past the end of the book
                    time.sleep(delay)
                    continue

                consecutive_empty = 0
                segments = extract_parallel_segments(data, ref, category)
                for seg in segments:
                    seg["category"] = category
                pairs.extend(segments)

                time.sleep(delay)

        elif depth >= 3:
            # Three-plus-level structure (e.g., commentary chapter:verse:comment).
            # Fetch chapter.verse by chapter.verse.
            # For the Jerusalem Talmud, start from 1.3 to avoid overlap with the Bavli.
            start_verse = 3 if category == "Jerusalem Talmud" else 1
            consecutive_empty_chapters = 0
            for chapter in range(1, 200):
                if len(pairs) >= max_segments:
                    break

                chapter_had_content = False
                # Use start_verse only for the first chapter
                first_verse = start_verse if chapter == 1 else 1
                for verse in range(first_verse, 100):
                    if len(pairs) >= max_segments:
                        break

                    ref = f"{text_title}.{chapter}.{verse}"
                    data = get_text_from_sefaria(ref)

                    # None means an API error (the ref doesn't exist)
                    if data is None:
                        break  # No more verses in this chapter

                    # An empty array means the verse exists but has no content
                    if not data.get("he"):
                        continue

                    chapter_had_content = True
                    segments = extract_parallel_segments(data, ref, category)
                    for seg in segments:
                        seg["category"] = category
                    pairs.extend(segments)

                    time.sleep(delay)

                if not chapter_had_content:
                    consecutive_empty_chapters += 1
                    if consecutive_empty_chapters >= 5:
                        break  # Probably past the end of the book
                else:
                    consecutive_empty_chapters = 0

    else:
        # Complex structure (SchemaNode) - try different ref patterns.
        # First try simple numeric refs (works for Sefer HaChinukh-style texts).
        consecutive_empty = 0
        for section in range(1, 1000):
            if len(pairs) >= max_segments:
                break

            ref = f"{text_title}.{section}"
            data = get_text_from_sefaria(ref)

            if data is None:
                break

            if not data.get("he"):
                consecutive_empty += 1
                if consecutive_empty >= 5:
                    break
                time.sleep(delay)
                continue

            consecutive_empty = 0
            segments = extract_parallel_segments(data, ref, category)
            for seg in segments:
                seg["category"] = category
            pairs.extend(segments)

            time.sleep(delay)

        # If we haven't reached max_segments, try chapter.verse-style refs (commentary pattern)
        if len(pairs) < max_segments:
            consecutive_empty = 0
            for chapter in range(1, 100):
                if len(pairs) >= max_segments:
                    break

                chapter_had_content = False
                for verse in range(1, 50):
                    if len(pairs) >= max_segments:
                        break

                    ref = f"{text_title}.{chapter}.{verse}"
                    data = get_text_from_sefaria(ref)

                    if data is None:
                        break  # This verse doesn't exist; try the next chapter

                    if data.get("he"):
                        chapter_had_content = True
                        consecutive_empty = 0
                        segments = extract_parallel_segments(data, ref, category)
                        for seg in segments:
                            seg["category"] = category
                        pairs.extend(segments)

                    time.sleep(delay)

                if not chapter_had_content:
                    consecutive_empty += 1
                    if consecutive_empty >= 5:
                        break

    return pairs[:max_segments]


def build_benchmark_dataset(
    output_path: str = "benchmark_data/benchmark.json",
    segments_per_text: int = 200,
    total_target: int = 10000,
) -> list[dict]:
    """
    Build the full benchmark dataset from all configured texts.

    Args:
        output_path: Path to save the benchmark JSON
        segments_per_text: Target segments per text
        total_target: Overall target segment count

    Returns:
        List of all benchmark pairs
    """
    all_pairs = []

    for category_key, category_info in BENCHMARK_TEXTS.items():
        category_name = category_info["category"]
        texts = category_info["texts"]

        print(f"\n{'='*60}")
        print(f"Processing category: {category_name}")
        print(f"{'='*60}")

        for text_title in texts:
            if len(all_pairs) >= total_target:
                break

            print(f"\nFetching: {text_title}")

            pairs = fetch_text_pairs(
                text_title,
                category_name,
                max_segments=segments_per_text,
            )

            print(f"  Got {len(pairs)} pairs")
            all_pairs.extend(pairs)

        if len(all_pairs) >= total_target:
            break

    # Save to file
    output_file = Path(output_path)
    output_file.parent.mkdir(parents=True, exist_ok=True)

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(all_pairs, f, ensure_ascii=False, indent=2)

    print(f"\n{'='*60}")
    print(f"Total pairs collected: {len(all_pairs)}")
    print(f"Saved to: {output_path}")

    # Save stats to a markdown file
    stats = get_benchmark_stats(all_pairs)
    save_stats_markdown(stats, output_path)

    return all_pairs


def load_benchmark_dataset(path: str = "benchmark_data/benchmark.json") -> list[dict]:
    """
    Load the pre-cached benchmark dataset.

    Args:
        path: Path to the benchmark JSON file

    Returns:
        List of benchmark pairs
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def get_benchmark_stats(pairs: list[dict]) -> dict:
    """
    Get statistics about the benchmark dataset.

    Args:
        pairs: List of benchmark pairs

    Returns:
        Dict with category counts and other stats
    """
    from collections import Counter

    categories = Counter(p["category"] for p in pairs)

    he_lengths = [len(p["he"]) for p in pairs]
    en_lengths = [len(p["en"]) for p in pairs]

    return {
        "total_pairs": len(pairs),
        "categories": dict(categories),
        "avg_he_length": sum(he_lengths) / len(he_lengths) if he_lengths else 0,
        "avg_en_length": sum(en_lengths) / len(en_lengths) if en_lengths else 0,
    }


def save_stats_markdown(stats: dict, data_path: str) -> str:
    """
    Save benchmark statistics to a markdown file alongside the data.

    Args:
        stats: Statistics dict from get_benchmark_stats()
        data_path: Path to the data file (used to derive the stats file path)

    Returns:
        Path to the saved markdown file
    """
    from datetime import datetime

    # Derive the markdown path from the data path
    data_file = Path(data_path)
    stats_path = data_file.with_suffix(".stats.md")

    # Build the markdown content
    lines = [
        "# Benchmark Dataset Statistics",
        "",
        f"**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        "",
        "## Summary",
        "",
        f"- **Total pairs:** {stats['total_pairs']:,}",
        f"- **Average Hebrew length:** {stats['avg_he_length']:.0f} chars",
        f"- **Average English length:** {stats['avg_en_length']:.0f} chars",
        "",
        "## Category Breakdown",
        "",
        "| Category | Count |",
        "|----------|-------|",
    ]

    # Sort categories by count (descending)
    sorted_categories = sorted(
        stats["categories"].items(),
        key=lambda x: x[1],
        reverse=True,
    )

    for category, count in sorted_categories:
        lines.append(f"| {category} | {count:,} |")

    lines.append("")

    # Write to file
    with open(stats_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))

    print(f"Stats saved to: {stats_path}")
    return str(stats_path)


if __name__ == "__main__":
    # Build the benchmark dataset
    print("Building Rabbinic Hebrew/Aramaic benchmark dataset...")
    pairs = build_benchmark_dataset()

    # Print stats
    stats = get_benchmark_stats(pairs)
    print("\nDataset Statistics:")
    print(f"  Total pairs: {stats['total_pairs']}")
    print(f"  Categories: {stats['categories']}")
    print(f"  Avg Hebrew length: {stats['avg_he_length']:.0f} chars")
    print(f"  Avg English length: {stats['avg_en_length']:.0f} chars")
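Sefaria's Talmud responses often nest segments per daf, which is why `extract_parallel_segments` flattens one level of nesting before pairing texts. A minimal, self-contained sketch of that alignment-preserving flattening step (the `flatten_segments` helper name is illustrative, not part of the module):

```python
def flatten_segments(he_text, en_text):
    """Flatten one level of nesting while keeping he/en segments aligned."""
    he_flat, en_flat = [], []
    for he_seg, en_seg in zip(he_text, en_text):
        if isinstance(he_seg, list):
            he_flat.extend(he_seg)
            # Guard against an English side that isn't nested like the Hebrew side
            en_flat.extend(en_seg if isinstance(en_seg, list) else [en_seg])
        else:
            he_flat.append(he_seg)
            en_flat.append(en_seg)
    return he_flat, en_flat


he, en = flatten_segments([["seg1", "seg2"], "seg3"], [["a", "b"], "c"])
# he == ["seg1", "seg2", "seg3"], en == ["a", "b", "c"]
```

Because both lists are walked in lockstep, segment i of the Hebrew output still corresponds to segment i of the English output after flattening.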
evaluation.py
ADDED
@@ -0,0 +1,384 @@
"""
Cross-lingual retrieval evaluation for the Rabbinic embedding benchmark.

Computes retrieval metrics to measure how well embedding models align
Hebrew/Aramaic source texts with their English translations.
"""

from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class EvaluationResults:
    """Container for evaluation results."""

    model_id: str
    model_name: str

    # Core retrieval metrics
    recall_at_1: float
    recall_at_5: float
    recall_at_10: float
    mrr: float  # Mean Reciprocal Rank

    # Additional metrics
    bitext_accuracy: float  # True-pair vs. random-pair classification
    avg_true_pair_similarity: float
    avg_random_pair_similarity: float

    # Metadata
    num_pairs: int
    categories: dict[str, int]

    def to_dict(self) -> dict:
        """Convert to a dictionary for JSON serialization."""
        return {
            "model_id": self.model_id,
            "model_name": self.model_name,
            "recall_at_1": self.recall_at_1,
            "recall_at_5": self.recall_at_5,
            "recall_at_10": self.recall_at_10,
            "mrr": self.mrr,
            "bitext_accuracy": self.bitext_accuracy,
            "avg_true_pair_similarity": self.avg_true_pair_similarity,
            "avg_random_pair_similarity": self.avg_random_pair_similarity,
            "num_pairs": self.num_pairs,
            "categories": self.categories,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "EvaluationResults":
        """Create an instance from a dictionary."""
        return cls(**data)


def compute_similarity_matrix(
    query_embeddings: np.ndarray,
    passage_embeddings: np.ndarray,
) -> np.ndarray:
    """
    Compute the cosine similarity matrix between queries and passages.

    Assumes embeddings are already L2-normalized, so the dot product
    equals cosine similarity.

    Args:
        query_embeddings: (N, D) array of query embeddings
        passage_embeddings: (M, D) array of passage embeddings

    Returns:
        (N, M) similarity matrix
    """
    return np.dot(query_embeddings, passage_embeddings.T)


def compute_retrieval_metrics(
    similarity_matrix: np.ndarray,
    k_values: list[int] = [1, 5, 10],
) -> dict[str, float]:
    """
    Compute retrieval metrics from a similarity matrix.

    Assumes the correct match for query i is passage i (the diagonal).

    Args:
        similarity_matrix: (N, N) similarity matrix where the diagonal holds true matches
        k_values: List of k values for Recall@k

    Returns:
        Dict with recall@k and mrr values
    """
    n = similarity_matrix.shape[0]

    # Get rankings for each query.
    # Negate to sort descending (highest similarity first).
    rankings = np.argsort(-similarity_matrix, axis=1)

    # Find the rank of the true match (diagonal) for each query
    true_ranks = np.zeros(n, dtype=int)
    for i in range(n):
        # Find the position of index i in the ranking for query i
        true_ranks[i] = np.where(rankings[i] == i)[0][0]

    results = {}

    # Recall@k: fraction of queries where the true match is in the top k
    for k in k_values:
        recall = np.mean(true_ranks < k)
        results[f"recall_at_{k}"] = float(recall)

    # MRR: Mean Reciprocal Rank
    reciprocal_ranks = 1.0 / (true_ranks + 1)  # +1 because ranks are 0-indexed
    results["mrr"] = float(np.mean(reciprocal_ranks))

    return results


def compute_bitext_accuracy(
    similarity_matrix: np.ndarray,
    num_negatives: int = 10,
) -> tuple[float, float, float]:
    """
    Compute bitext mining accuracy.

    For each true pair, sample random negative pairs and check whether the
    model correctly ranks the true pair higher.

    Args:
        similarity_matrix: (N, N) similarity matrix
        num_negatives: Number of negative samples per true pair

    Returns:
        Tuple of (accuracy, avg_true_sim, avg_random_sim)
    """
    n = similarity_matrix.shape[0]

    # True-pair similarities (the diagonal)
    true_similarities = np.diag(similarity_matrix)

    # Sample random negative pairs
    correct = 0
    total = 0
    random_sims = []

    rng = np.random.default_rng(42)

    for i in range(n):
        true_sim = true_similarities[i]

        # Sample random passage indices (excluding the true match)
        neg_indices = rng.choice(
            [j for j in range(n) if j != i],
            size=min(num_negatives, n - 1),
            replace=False,
        )

        for j in neg_indices:
            neg_sim = similarity_matrix[i, j]
            random_sims.append(neg_sim)

            if true_sim > neg_sim:
                correct += 1
            total += 1

    accuracy = correct / total if total > 0 else 0.0
    avg_true = float(np.mean(true_similarities))
    avg_random = float(np.mean(random_sims)) if random_sims else 0.0

    return accuracy, avg_true, avg_random


def evaluate_model(
    model,
    benchmark_pairs: list[dict],
    batch_size: int = 32,
    max_pairs: Optional[int] = None,
    progress_callback=None,
) -> EvaluationResults:
    """
    Run a full evaluation of a model on the benchmark.

    Args:
        model: EmbeddingModel instance
        benchmark_pairs: List of benchmark pairs with 'he', 'en', 'category' keys
        batch_size: Batch size for encoding
        max_pairs: Maximum pairs to evaluate (for faster testing)
        progress_callback: Optional callback for progress updates

    Returns:
        EvaluationResults with all metrics
    """
    # Use the streaming version and return the final result
    result = None
    for item in evaluate_model_streaming(model, benchmark_pairs, batch_size, max_pairs):
        if isinstance(item, str):
            if progress_callback:
                progress_callback(0.5, item)
        else:
            result = item
    return result


def evaluate_model_streaming(
    model,
    benchmark_pairs: list[dict],
    batch_size: int = 32,
    max_pairs: Optional[int] = None,
):
    """
    Run the evaluation with streaming progress updates.

    Yields progress strings during encoding, then yields the final EvaluationResults.

    Args:
        model: EmbeddingModel instance
        benchmark_pairs: List of benchmark pairs with 'he', 'en', 'category' keys
        batch_size: Batch size for encoding
        max_pairs: Maximum pairs to evaluate (for faster testing)

    Yields:
        Progress strings, then the final EvaluationResults
    """
    from collections import Counter

    # Optionally limit the number of pairs
    if max_pairs and len(benchmark_pairs) > max_pairs:
        benchmark_pairs = benchmark_pairs[:max_pairs]

    # Extract texts
    he_texts = [p["he"] for p in benchmark_pairs]
    en_texts = [p["en"] for p in benchmark_pairs]
    categories = Counter(p.get("category", "Unknown") for p in benchmark_pairs)
    n_total = len(he_texts)

    # Encode Hebrew texts in batches with progress updates
    yield f"⏳ Encoding Hebrew/Aramaic texts: 0/{n_total:,}"
    he_embeddings_list = []
    for i in range(0, len(he_texts), batch_size):
        batch = he_texts[i:i + batch_size]
        batch_emb = model.encode(
            batch,
            is_query=True,
            batch_size=batch_size,
            show_progress=False,
        )
        he_embeddings_list.append(batch_emb)
        done = min(i + batch_size, len(he_texts))
        yield f"⏳ Encoding Hebrew/Aramaic texts: {done:,}/{n_total:,}"

    he_embeddings = np.vstack(he_embeddings_list)

    # Encode English texts in batches with progress updates
    yield f"⏳ Encoding English texts: 0/{n_total:,}"
    en_embeddings_list = []
    for i in range(0, len(en_texts), batch_size):
        batch = en_texts[i:i + batch_size]
        batch_emb = model.encode(
            batch,
            is_query=False,
            batch_size=batch_size,
            show_progress=False,
        )
        en_embeddings_list.append(batch_emb)
        done = min(i + batch_size, len(en_texts))
        yield f"⏳ Encoding English texts: {done:,}/{n_total:,}"

    en_embeddings = np.vstack(en_embeddings_list)

    yield "⏳ Computing similarity matrix..."
    similarity_matrix = compute_similarity_matrix(he_embeddings, en_embeddings)

    yield "⏳ Computing retrieval metrics..."
    retrieval_metrics = compute_retrieval_metrics(similarity_matrix)

    yield "⏳ Computing bitext accuracy..."
    bitext_acc, avg_true_sim, avg_random_sim = compute_bitext_accuracy(
        similarity_matrix
    )

    # Yield the final results
    yield EvaluationResults(
        model_id=model.model_id,
        model_name=model.name,
        recall_at_1=retrieval_metrics["recall_at_1"],
        recall_at_5=retrieval_metrics["recall_at_5"],
        recall_at_10=retrieval_metrics["recall_at_10"],
        mrr=retrieval_metrics["mrr"],
        bitext_accuracy=bitext_acc,
        avg_true_pair_similarity=avg_true_sim,
        avg_random_pair_similarity=avg_random_sim,
        num_pairs=len(benchmark_pairs),
        categories=dict(categories),
    )


def evaluate_by_category(
    model,
    benchmark_pairs: list[dict],
    batch_size: int = 32,
) -> dict[str, EvaluationResults]:
    """
    Run the evaluation broken down by category.

    Args:
        model: EmbeddingModel instance
        benchmark_pairs: List of benchmark pairs
        batch_size: Batch size for encoding

    Returns:
        Dict mapping category name to EvaluationResults
    """
    from collections import defaultdict

    # Group pairs by category
    by_category = defaultdict(list)
    for pair in benchmark_pairs:
        category = pair.get("category", "Unknown")
        by_category[category].append(pair)

    results = {}
    for category, pairs in by_category.items():
        print(f"Evaluating category: {category} ({len(pairs)} pairs)")
        results[category] = evaluate_model(model, pairs, batch_size=batch_size)

    return results


def get_rank_distribution(
    similarity_matrix: np.ndarray,
    bins: list[int] = [1, 5, 10, 50, 100],
) -> dict[str, int]:
    """
    Get the distribution of true-match ranks.

    Args:
        similarity_matrix: (N, N) similarity matrix
        bins: Bin boundaries for the histogram

    Returns:
        Dict mapping bin labels to counts
    """
    n = similarity_matrix.shape[0]
    rankings = np.argsort(-similarity_matrix, axis=1)

    # Find the true rank for each query
    true_ranks = np.zeros(n, dtype=int)
    for i in range(n):
        true_ranks[i] = np.where(rankings[i] == i)[0][0]

    # Build the histogram
    distribution = {}
    prev_bin = 0
    for bin_edge in bins:
        count = np.sum((true_ranks >= prev_bin) & (true_ranks < bin_edge))
        label = f"{prev_bin+1}-{bin_edge}" if prev_bin > 0 else f"Top {bin_edge}"
        distribution[label] = int(count)
        prev_bin = bin_edge

    # Count the remainder
    remaining = np.sum(true_ranks >= bins[-1])
    distribution[f">{bins[-1]}"] = int(remaining)

    return distribution


if __name__ == "__main__":
|
| 367 |
+
# Test with sample data
|
| 368 |
+
print("Testing evaluation functions...")
|
| 369 |
+
|
| 370 |
+
# Create sample similarity matrix (perfect retrieval)
|
| 371 |
+
n = 100
|
| 372 |
+
perfect_matrix = np.eye(n) + np.random.randn(n, n) * 0.1
|
| 373 |
+
|
| 374 |
+
metrics = compute_retrieval_metrics(perfect_matrix)
|
| 375 |
+
print(f"Perfect retrieval metrics: {metrics}")
|
| 376 |
+
|
| 377 |
+
# Test with random matrix
|
| 378 |
+
random_matrix = np.random.randn(n, n)
|
| 379 |
+
random_matrix = random_matrix / np.linalg.norm(random_matrix, axis=1, keepdims=True)
|
| 380 |
+
random_matrix = np.dot(random_matrix, random_matrix.T)
|
| 381 |
+
|
| 382 |
+
metrics = compute_retrieval_metrics(random_matrix)
|
| 383 |
+
print(f"Random retrieval metrics: {metrics}")
|
| 384 |
+
|
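The rank-bucketing logic in `get_rank_distribution` can be sanity-checked in isolation. `rank_distribution` below is an illustrative standalone re-implementation (not part of the commit), run on a toy similarity matrix where every query's true match sits on the diagonal:

```python
import numpy as np

def rank_distribution(sim, bins=(1, 5, 10)):
    # Same bucketing as get_rank_distribution: for each row (query),
    # find where the true match (the diagonal entry) lands among the
    # candidates sorted by descending similarity.
    rankings = np.argsort(-sim, axis=1)
    true_ranks = np.array([np.where(rankings[i] == i)[0][0] for i in range(sim.shape[0])])
    dist, prev = {}, 0
    for edge in bins:
        label = f"{prev + 1}-{edge}" if prev > 0 else f"Top {edge}"
        dist[label] = int(np.sum((true_ranks >= prev) & (true_ranks < edge)))
        prev = edge
    dist[f">{bins[-1]}"] = int(np.sum(true_ranks >= bins[-1]))
    return dist

# Identity similarity: every query's true match ranks first
print(rank_distribution(np.eye(4), bins=(1, 2)))  # {'Top 1': 4, '2-2': 0, '>2': 0}
```

With a perfect similarity matrix everything falls in the "Top 1" bucket; a degraded model shifts mass into the higher-rank buckets.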
models.py
ADDED
@@ -0,0 +1,1063 @@
| 1 |
+
"""
|
| 2 |
+
Model loading and embedding interface for the Rabbinic embedding benchmark.
|
| 3 |
+
|
| 4 |
+
Supports:
|
| 5 |
+
- Curated models from Hugging Face (sentence-transformers)
|
| 6 |
+
- Any Hugging Face sentence-transformer model
|
| 7 |
+
- API-based models (OpenAI, Voyage AI, Google Gemini)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
from abc import ABC, abstractmethod
|
| 12 |
+
from typing import Optional
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
# Curated local models known to work well for multilingual tasks
|
| 16 |
+
CURATED_MODELS = {
|
| 17 |
+
"intfloat/multilingual-e5-large": {
|
| 18 |
+
"name": "Multilingual E5 Large",
|
| 19 |
+
"description": "Strong multilingual model from Microsoft, 560M params",
|
| 20 |
+
"type": "local",
|
| 21 |
+
"query_prefix": "query: ",
|
| 22 |
+
"passage_prefix": "passage: ",
|
| 23 |
+
},
|
| 24 |
+
"intfloat/multilingual-e5-base": {
|
| 25 |
+
"name": "Multilingual E5 Base",
|
| 26 |
+
"description": "Smaller multilingual E5, 278M params",
|
| 27 |
+
"type": "local",
|
| 28 |
+
"query_prefix": "query: ",
|
| 29 |
+
"passage_prefix": "passage: ",
|
| 30 |
+
},
|
| 31 |
+
"sentence-transformers/paraphrase-multilingual-mpnet-base-v2": {
|
| 32 |
+
"name": "Multilingual MPNet",
|
| 33 |
+
"description": "Classic multilingual sentence transformer, 278M params",
|
| 34 |
+
"type": "local",
|
| 35 |
+
"query_prefix": "",
|
| 36 |
+
"passage_prefix": "",
|
| 37 |
+
},
|
| 38 |
+
"BAAI/bge-m3": {
|
| 39 |
+
"name": "BGE-M3",
|
| 40 |
+
"description": "Multi-lingual, multi-functionality, multi-granularity model from BAAI",
|
| 41 |
+
"type": "local",
|
| 42 |
+
"query_prefix": "",
|
| 43 |
+
"passage_prefix": "",
|
| 44 |
+
},
|
| 45 |
+
"intfloat/e5-mistral-7b-instruct": {
|
| 46 |
+
"name": "E5 Mistral 7B",
|
| 47 |
+
"description": "Large instruction-tuned embedding model, 7B params (requires GPU)",
|
| 48 |
+
"type": "local",
|
| 49 |
+
"query_prefix": "Instruct: Retrieve semantically similar text\nQuery: ",
|
| 50 |
+
"passage_prefix": "",
|
| 51 |
+
},
|
| 52 |
+
"Alibaba-NLP/gte-multilingual-base": {
|
| 53 |
+
"name": "GTE Multilingual Base",
|
| 54 |
+
"description": "General Text Embeddings multilingual model from Alibaba",
|
| 55 |
+
"type": "local",
|
| 56 |
+
"query_prefix": "",
|
| 57 |
+
"passage_prefix": "",
|
| 58 |
+
},
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
# API-based models
|
| 62 |
+
API_MODELS = {
|
| 63 |
+
"openai/text-embedding-3-large": {
|
| 64 |
+
"name": "OpenAI text-embedding-3-large",
|
| 65 |
+
"description": "OpenAI's best embedding model, 3072 dimensions (API key required)",
|
| 66 |
+
"type": "openai",
|
| 67 |
+
"model_name": "text-embedding-3-large",
|
| 68 |
+
"dimensions": 3072,
|
| 69 |
+
},
|
| 70 |
+
"openai/text-embedding-3-small": {
|
| 71 |
+
"name": "OpenAI text-embedding-3-small",
|
| 72 |
+
"description": "OpenAI's efficient embedding model, 1536 dimensions (API key required)",
|
| 73 |
+
"type": "openai",
|
| 74 |
+
"model_name": "text-embedding-3-small",
|
| 75 |
+
"dimensions": 1536,
|
| 76 |
+
},
|
| 77 |
+
"openai/text-embedding-ada-002": {
|
| 78 |
+
"name": "OpenAI Ada 002",
|
| 79 |
+
"description": "OpenAI's legacy embedding model, 1536 dimensions (API key required)",
|
| 80 |
+
"type": "openai",
|
| 81 |
+
"model_name": "text-embedding-ada-002",
|
| 82 |
+
"dimensions": 1536,
|
| 83 |
+
},
|
| 84 |
+
"voyage/voyage-3.5": {
|
| 85 |
+
"name": "Voyage AI voyage-3.5",
|
| 86 |
+
"description": "Voyage AI's latest embedding model (API key required)",
|
| 87 |
+
"type": "voyage",
|
| 88 |
+
"model_name": "voyage-3.5",
|
| 89 |
+
"dimensions": 1024,
|
| 90 |
+
},
|
| 91 |
+
"voyage/voyage-3.5-lite": {
|
| 92 |
+
"name": "Voyage AI voyage-3.5-lite",
|
| 93 |
+
"description": "Voyage AI's efficient embedding model (API key required)",
|
| 94 |
+
"type": "voyage",
|
| 95 |
+
"model_name": "voyage-3.5-lite",
|
| 96 |
+
"dimensions": 1024,
|
| 97 |
+
},
|
| 98 |
+
"voyage/voyage-3": {
|
| 99 |
+
"name": "Voyage AI voyage-3",
|
| 100 |
+
"description": "Voyage AI's general purpose embedding model (API key required)",
|
| 101 |
+
"type": "voyage",
|
| 102 |
+
"model_name": "voyage-3",
|
| 103 |
+
"dimensions": 1024,
|
| 104 |
+
},
|
| 105 |
+
"voyage/voyage-3-lite": {
|
| 106 |
+
"name": "Voyage AI voyage-3-lite",
|
| 107 |
+
"description": "Voyage AI's lightweight embedding model (API key required)",
|
| 108 |
+
"type": "voyage",
|
| 109 |
+
"model_name": "voyage-3-lite",
|
| 110 |
+
"dimensions": 512,
|
| 111 |
+
},
|
| 112 |
+
"voyage/voyage-multilingual-2": {
|
| 113 |
+
"name": "Voyage AI voyage-multilingual-2",
|
| 114 |
+
"description": "Voyage AI's multilingual embedding model, optimized for non-English (API key required)",
|
| 115 |
+
"type": "voyage",
|
| 116 |
+
"model_name": "voyage-multilingual-2",
|
| 117 |
+
"dimensions": 1024,
|
| 118 |
+
},
|
| 119 |
+
"gemini/gemini-embedding-001": {
|
| 120 |
+
"name": "Gemini Embedding 001",
|
| 121 |
+
"description": "Google's Gemini embedding model, 3072 dimensions (API key required)",
|
| 122 |
+
"type": "gemini",
|
| 123 |
+
"model_name": "gemini-embedding-001",
|
| 124 |
+
"dimensions": 3072,
|
| 125 |
+
},
|
| 126 |
+
"gemini/gemini-embedding-001-768": {
|
| 127 |
+
"name": "Gemini Embedding 001 (768d)",
|
| 128 |
+
"description": "Google's Gemini embedding model, 768 dimensions (API key required)",
|
| 129 |
+
"type": "gemini",
|
| 130 |
+
"model_name": "gemini-embedding-001",
|
| 131 |
+
"dimensions": 768,
|
| 132 |
+
},
|
| 133 |
+
"gemini/gemini-embedding-001-1536": {
|
| 134 |
+
"name": "Gemini Embedding 001 (1536d)",
|
| 135 |
+
"description": "Google's Gemini embedding model, 1536 dimensions (API key required)",
|
| 136 |
+
"type": "gemini",
|
| 137 |
+
"model_name": "gemini-embedding-001",
|
| 138 |
+
"dimensions": 1536,
|
| 139 |
+
},
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
# Merge all models for easy lookup
|
| 143 |
+
ALL_MODELS = {**CURATED_MODELS, **API_MODELS}
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class BaseEmbeddingModel(ABC):
|
| 147 |
+
"""Abstract base class for embedding models."""
|
| 148 |
+
|
| 149 |
+
model_id: str
|
| 150 |
+
embedding_dim: int
|
| 151 |
+
|
| 152 |
+
@abstractmethod
|
| 153 |
+
def encode(
|
| 154 |
+
self,
|
| 155 |
+
texts: list[str],
|
| 156 |
+
is_query: bool = False,
|
| 157 |
+
batch_size: int = 32,
|
| 158 |
+
show_progress: bool = True,
|
| 159 |
+
normalize: bool = True,
|
| 160 |
+
) -> np.ndarray:
|
| 161 |
+
"""Encode texts to embeddings."""
|
| 162 |
+
pass
|
| 163 |
+
|
| 164 |
+
@property
|
| 165 |
+
@abstractmethod
|
| 166 |
+
def name(self) -> str:
|
| 167 |
+
"""Get display name for the model."""
|
| 168 |
+
pass
|
| 169 |
+
|
| 170 |
+
@property
|
| 171 |
+
@abstractmethod
|
| 172 |
+
def description(self) -> str:
|
| 173 |
+
"""Get description for the model."""
|
| 174 |
+
pass
|
| 175 |
+
|
| 176 |
+
def encode_pairs(
|
| 177 |
+
self,
|
| 178 |
+
he_texts: list[str],
|
| 179 |
+
en_texts: list[str],
|
| 180 |
+
batch_size: int = 32,
|
| 181 |
+
show_progress: bool = True,
|
| 182 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
| 183 |
+
"""
|
| 184 |
+
Encode parallel Hebrew/English text pairs.
|
| 185 |
+
|
| 186 |
+
Args:
|
| 187 |
+
he_texts: Hebrew/Aramaic source texts
|
| 188 |
+
en_texts: English translations
|
| 189 |
+
batch_size: Batch size for encoding
|
| 190 |
+
show_progress: Whether to show progress bar
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
Tuple of (hebrew_embeddings, english_embeddings)
|
| 194 |
+
"""
|
| 195 |
+
he_embeddings = self.encode(
|
| 196 |
+
he_texts,
|
| 197 |
+
is_query=True,
|
| 198 |
+
batch_size=batch_size,
|
| 199 |
+
show_progress=show_progress,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
en_embeddings = self.encode(
|
| 203 |
+
en_texts,
|
| 204 |
+
is_query=False,
|
| 205 |
+
batch_size=batch_size,
|
| 206 |
+
show_progress=show_progress,
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
return he_embeddings, en_embeddings
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class EmbeddingModel(BaseEmbeddingModel):
|
| 213 |
+
"""
|
| 214 |
+
Wrapper for sentence-transformer models with consistent interface.
|
| 215 |
+
"""
|
| 216 |
+
|
| 217 |
+
def __init__(
|
| 218 |
+
self,
|
| 219 |
+
model_id: str,
|
| 220 |
+
device: Optional[str] = None,
|
| 221 |
+
max_length: int = 512,
|
| 222 |
+
):
|
| 223 |
+
"""
|
| 224 |
+
Initialize the embedding model.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
model_id: Hugging Face model ID
|
| 228 |
+
device: Device to use ('cuda', 'cpu', or None for auto)
|
| 229 |
+
max_length: Maximum sequence length for tokenization
|
| 230 |
+
"""
|
| 231 |
+
from sentence_transformers import SentenceTransformer
|
| 232 |
+
import torch
|
| 233 |
+
|
| 234 |
+
self.model_id = model_id
|
| 235 |
+
self.max_length = max_length
|
| 236 |
+
|
| 237 |
+
# Auto-detect device
|
| 238 |
+
if device is None:
|
| 239 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 240 |
+
self.device = device
|
| 241 |
+
|
| 242 |
+
# Get model config if it's a curated model
|
| 243 |
+
self.config = CURATED_MODELS.get(model_id, {
|
| 244 |
+
"name": model_id.split("/")[-1],
|
| 245 |
+
"description": "Custom model",
|
| 246 |
+
"type": "local",
|
| 247 |
+
"query_prefix": "",
|
| 248 |
+
"passage_prefix": "",
|
| 249 |
+
})
|
| 250 |
+
|
| 251 |
+
# Load the model
|
| 252 |
+
print(f"Loading model: {model_id} on {device}")
|
| 253 |
+
self.model = SentenceTransformer(model_id, device=device)
|
| 254 |
+
|
| 255 |
+
# Set max sequence length if supported
|
| 256 |
+
if hasattr(self.model, "max_seq_length"):
|
| 257 |
+
self.model.max_seq_length = min(max_length, self.model.max_seq_length)
|
| 258 |
+
|
| 259 |
+
self.embedding_dim = self.model.get_sentence_embedding_dimension()
|
| 260 |
+
print(f"Model loaded. Embedding dimension: {self.embedding_dim}")
|
| 261 |
+
|
| 262 |
+
def encode(
|
| 263 |
+
self,
|
| 264 |
+
texts: list[str],
|
| 265 |
+
is_query: bool = False,
|
| 266 |
+
batch_size: int = 32,
|
| 267 |
+
show_progress: bool = True,
|
| 268 |
+
normalize: bool = True,
|
| 269 |
+
) -> np.ndarray:
|
| 270 |
+
"""
|
| 271 |
+
Encode texts to embeddings.
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
texts: List of texts to encode
|
| 275 |
+
is_query: Whether these are queries (vs passages) for asymmetric models
|
| 276 |
+
batch_size: Batch size for encoding
|
| 277 |
+
show_progress: Whether to show progress bar
|
| 278 |
+
normalize: Whether to L2-normalize embeddings
|
| 279 |
+
|
| 280 |
+
Returns:
|
| 281 |
+
numpy array of shape (len(texts), embedding_dim)
|
| 282 |
+
"""
|
| 283 |
+
# Add prefix if needed (for E5-style models)
|
| 284 |
+
prefix = self.config["query_prefix"] if is_query else self.config["passage_prefix"]
|
| 285 |
+
if prefix:
|
| 286 |
+
texts = [prefix + t for t in texts]
|
| 287 |
+
|
| 288 |
+
embeddings = self.model.encode(
|
| 289 |
+
texts,
|
| 290 |
+
batch_size=batch_size,
|
| 291 |
+
show_progress_bar=show_progress,
|
| 292 |
+
normalize_embeddings=normalize,
|
| 293 |
+
convert_to_numpy=True,
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
return embeddings
|
| 297 |
+
|
| 298 |
+
@property
|
| 299 |
+
def name(self) -> str:
|
| 300 |
+
"""Get display name for the model."""
|
| 301 |
+
return self.config.get("name", self.model_id)
|
| 302 |
+
|
| 303 |
+
@property
|
| 304 |
+
def description(self) -> str:
|
| 305 |
+
"""Get description for the model."""
|
| 306 |
+
return self.config.get("description", "")
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class OpenAIEmbeddingModel(BaseEmbeddingModel):
|
| 310 |
+
"""
|
| 311 |
+
Wrapper for OpenAI embedding API with consistent interface.
|
| 312 |
+
"""
|
| 313 |
+
|
| 314 |
+
# OpenAI embedding models have an 8191 token limit
|
| 315 |
+
MAX_TOKENS = 8191
|
| 316 |
+
|
| 317 |
+
def __init__(
|
| 318 |
+
self,
|
| 319 |
+
model_id: str,
|
| 320 |
+
api_key: Optional[str] = None,
|
| 321 |
+
):
|
| 322 |
+
"""
|
| 323 |
+
Initialize the OpenAI embedding model.
|
| 324 |
+
|
| 325 |
+
Args:
|
| 326 |
+
model_id: Model ID in format 'openai/model-name'
|
| 327 |
+
api_key: OpenAI API key (or uses OPENAI_API_KEY env var)
|
| 328 |
+
"""
|
| 329 |
+
try:
|
| 330 |
+
from openai import OpenAI
|
| 331 |
+
except ImportError:
|
| 332 |
+
raise ImportError(
|
| 333 |
+
"OpenAI package not installed. Install with: pip install openai"
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
self.model_id = model_id
|
| 337 |
+
|
| 338 |
+
# Get API key from parameter or environment
|
| 339 |
+
api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
| 340 |
+
if not api_key:
|
| 341 |
+
raise ValueError(
|
| 342 |
+
"OpenAI API key required. Set OPENAI_API_KEY environment variable "
|
| 343 |
+
"or pass api_key parameter."
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
self.client = OpenAI(api_key=api_key)
|
| 347 |
+
|
| 348 |
+
# Get model config
|
| 349 |
+
self.config = API_MODELS.get(model_id, {
|
| 350 |
+
"name": model_id,
|
| 351 |
+
"description": "OpenAI embedding model",
|
| 352 |
+
"type": "openai",
|
| 353 |
+
"model_name": model_id.replace("openai/", ""),
|
| 354 |
+
"dimensions": 1536,
|
| 355 |
+
})
|
| 356 |
+
|
| 357 |
+
self._model_name = self.config["model_name"]
|
| 358 |
+
self.embedding_dim = self.config["dimensions"]
|
| 359 |
+
|
| 360 |
+
# Initialize tokenizer for truncation
|
| 361 |
+
self._encoding = None
|
| 362 |
+
try:
|
| 363 |
+
import tiktoken
|
| 364 |
+
self._encoding = tiktoken.encoding_for_model(self._model_name)
|
| 365 |
+
except Exception:
|
| 366 |
+
# Fall back to cl100k_base which is used by embedding models
|
| 367 |
+
try:
|
| 368 |
+
import tiktoken
|
| 369 |
+
self._encoding = tiktoken.get_encoding("cl100k_base")
|
| 370 |
+
except Exception:
|
| 371 |
+
print("Warning: tiktoken not available, using character-based truncation")
|
| 372 |
+
|
| 373 |
+
print(f"Initialized OpenAI embedding model: {self._model_name}")
|
| 374 |
+
print(f"Embedding dimension: {self.embedding_dim}")
|
| 375 |
+
|
| 376 |
+
def _truncate_text(self, text: str) -> str:
|
| 377 |
+
"""Truncate text to fit within token limit."""
|
| 378 |
+
if self._encoding is not None:
|
| 379 |
+
# Use tiktoken for accurate token counting
|
| 380 |
+
tokens = self._encoding.encode(text)
|
| 381 |
+
if len(tokens) > self.MAX_TOKENS:
|
| 382 |
+
tokens = tokens[:self.MAX_TOKENS]
|
| 383 |
+
return self._encoding.decode(tokens)
|
| 384 |
+
return text
|
| 385 |
+
else:
|
| 386 |
+
# Fallback: rough character-based truncation
|
| 387 |
+
# Assume ~3 chars per token for Hebrew/mixed text (conservative)
|
| 388 |
+
max_chars = self.MAX_TOKENS * 3
|
| 389 |
+
if len(text) > max_chars:
|
| 390 |
+
return text[:max_chars]
|
| 391 |
+
return text
|
| 392 |
+
|
| 393 |
+
def encode(
|
| 394 |
+
self,
|
| 395 |
+
texts: list[str],
|
| 396 |
+
is_query: bool = False,
|
| 397 |
+
batch_size: int = 100, # OpenAI supports larger batches
|
| 398 |
+
show_progress: bool = True,
|
| 399 |
+
normalize: bool = True,
|
| 400 |
+
) -> np.ndarray:
|
| 401 |
+
"""
|
| 402 |
+
Encode texts to embeddings using OpenAI API.
|
| 403 |
+
|
| 404 |
+
Args:
|
| 405 |
+
texts: List of texts to encode
|
| 406 |
+
is_query: Not used for OpenAI (symmetric embeddings)
|
| 407 |
+
batch_size: Batch size for API calls
|
| 408 |
+
show_progress: Whether to show progress bar
|
| 409 |
+
normalize: Whether to L2-normalize embeddings (OpenAI already normalizes)
|
| 410 |
+
|
| 411 |
+
Returns:
|
| 412 |
+
numpy array of shape (len(texts), embedding_dim)
|
| 413 |
+
"""
|
| 414 |
+
import time
|
| 415 |
+
|
| 416 |
+
all_embeddings = []
|
| 417 |
+
total_batches = (len(texts) + batch_size - 1) // batch_size
|
| 418 |
+
|
| 419 |
+
for i in range(0, len(texts), batch_size):
|
| 420 |
+
batch = texts[i:i + batch_size]
|
| 421 |
+
batch_num = i // batch_size + 1
|
| 422 |
+
|
| 423 |
+
if show_progress:
|
| 424 |
+
print(f" Encoding batch {batch_num}/{total_batches}...")
|
| 425 |
+
|
| 426 |
+
# Retry logic for API calls
|
| 427 |
+
max_retries = 3
|
| 428 |
+
for attempt in range(max_retries):
|
| 429 |
+
try:
|
| 430 |
+
response = self.client.embeddings.create(
|
| 431 |
+
model=self._model_name,
|
| 432 |
+
input=batch,
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
# Extract embeddings from response
|
| 436 |
+
batch_embeddings = [item.embedding for item in response.data]
|
| 437 |
+
all_embeddings.extend(batch_embeddings)
|
| 438 |
+
break
|
| 439 |
+
|
| 440 |
+
except Exception as e:
|
| 441 |
+
if attempt < max_retries - 1:
|
| 442 |
+
wait_time = 2 ** attempt
|
| 443 |
+
print(f" API error, retrying in {wait_time}s: {e}")
|
| 444 |
+
time.sleep(wait_time)
|
| 445 |
+
else:
|
| 446 |
+
raise RuntimeError(f"OpenAI API error after {max_retries} retries: {e}")
|
| 447 |
+
|
| 448 |
+
# Small delay to avoid rate limits
|
| 449 |
+
if i + batch_size < len(texts):
|
| 450 |
+
time.sleep(0.1)
|
| 451 |
+
|
| 452 |
+
embeddings = np.array(all_embeddings, dtype=np.float32)
|
| 453 |
+
|
| 454 |
+
# OpenAI embeddings are already normalized, but normalize if requested
|
| 455 |
+
if normalize:
|
| 456 |
+
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
| 457 |
+
embeddings = embeddings / np.maximum(norms, 1e-10)
|
| 458 |
+
|
| 459 |
+
return embeddings
|
| 460 |
+
|
| 461 |
+
@property
|
| 462 |
+
def name(self) -> str:
|
| 463 |
+
"""Get display name for the model."""
|
| 464 |
+
return self.config.get("name", self.model_id)
|
| 465 |
+
|
| 466 |
+
@property
|
| 467 |
+
def description(self) -> str:
|
| 468 |
+
"""Get description for the model."""
|
| 469 |
+
return self.config.get("description", "")
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
class VoyageEmbeddingModel(BaseEmbeddingModel):
|
| 473 |
+
"""
|
| 474 |
+
Wrapper for Voyage AI embedding API with consistent interface.
|
| 475 |
+
"""
|
| 476 |
+
|
| 477 |
+
def __init__(
|
| 478 |
+
self,
|
| 479 |
+
model_id: str,
|
| 480 |
+
api_key: Optional[str] = None,
|
| 481 |
+
):
|
| 482 |
+
"""
|
| 483 |
+
Initialize the Voyage AI embedding model.
|
| 484 |
+
|
| 485 |
+
Args:
|
| 486 |
+
model_id: Model ID in format 'voyage/model-name'
|
| 487 |
+
api_key: Voyage API key (or uses VOYAGE_API_KEY env var)
|
| 488 |
+
"""
|
| 489 |
+
try:
|
| 490 |
+
import voyageai
|
| 491 |
+
except ImportError:
|
| 492 |
+
raise ImportError(
|
| 493 |
+
"Voyage AI package not installed. Install with: pip install voyageai"
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
self.model_id = model_id
|
| 497 |
+
|
| 498 |
+
# Get API key from parameter or environment
|
| 499 |
+
api_key = api_key or os.environ.get("VOYAGE_API_KEY")
|
| 500 |
+
if not api_key:
|
| 501 |
+
raise ValueError(
|
| 502 |
+
"Voyage API key required. Set VOYAGE_API_KEY environment variable "
|
| 503 |
+
"or pass api_key parameter."
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
self.client = voyageai.Client(api_key=api_key)
|
| 507 |
+
|
| 508 |
+
# Get model config
|
| 509 |
+
self.config = API_MODELS.get(model_id, {
|
| 510 |
+
"name": model_id,
|
| 511 |
+
"description": "Voyage AI embedding model",
|
| 512 |
+
"type": "voyage",
|
| 513 |
+
"model_name": model_id.replace("voyage/", ""),
|
| 514 |
+
"dimensions": 1024, # Default dimension
|
| 515 |
+
})
|
| 516 |
+
|
| 517 |
+
self._model_name = self.config["model_name"]
|
| 518 |
+
self.embedding_dim = self.config["dimensions"]
|
| 519 |
+
|
| 520 |
+
print(f"Initialized Voyage AI embedding model: {self._model_name}")
|
| 521 |
+
print(f"Embedding dimension: {self.embedding_dim}")
|
| 522 |
+
|
| 523 |
+
def encode(
|
| 524 |
+
self,
|
| 525 |
+
texts: list[str],
|
| 526 |
+
is_query: bool = False,
|
| 527 |
+
batch_size: int = 128, # Voyage supports larger batches
|
| 528 |
+
show_progress: bool = True,
|
| 529 |
+
normalize: bool = True,
|
| 530 |
+
) -> np.ndarray:
|
| 531 |
+
"""
|
| 532 |
+
Encode texts to embeddings using Voyage AI API.
|
| 533 |
+
|
| 534 |
+
Args:
|
| 535 |
+
texts: List of texts to encode
|
| 536 |
+
is_query: Whether these are queries (Voyage supports input_type)
|
| 537 |
+
batch_size: Batch size for API calls
|
| 538 |
+
show_progress: Whether to show progress bar
|
| 539 |
+
normalize: Whether to L2-normalize embeddings
|
| 540 |
+
|
| 541 |
+
Returns:
|
| 542 |
+
numpy array of shape (len(texts), embedding_dim)
|
| 543 |
+
"""
|
| 544 |
+
import time
|
| 545 |
+
|
| 546 |
+
all_embeddings = []
|
| 547 |
+
total_batches = (len(texts) + batch_size - 1) // batch_size
|
| 548 |
+
|
| 549 |
+
# Voyage supports input_type for asymmetric embeddings
|
| 550 |
+
input_type = "query" if is_query else "document"
|
| 551 |
+
|
| 552 |
+
for i in range(0, len(texts), batch_size):
|
| 553 |
+
batch = texts[i:i + batch_size]
|
| 554 |
+
batch_num = i // batch_size + 1
|
| 555 |
+
|
| 556 |
+
if show_progress:
|
| 557 |
+
print(f" Encoding batch {batch_num}/{total_batches}...")
|
| 558 |
+
|
| 559 |
+
# Retry logic for API calls
|
| 560 |
+
max_retries = 3
|
| 561 |
+
for attempt in range(max_retries):
|
| 562 |
+
try:
|
| 563 |
+
result = self.client.embed(
|
| 564 |
+
batch,
|
| 565 |
+
model=self._model_name,
|
| 566 |
+
input_type=input_type,
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
# Extract embeddings from response
|
| 570 |
+
batch_embeddings = result.embeddings
|
| 571 |
+
all_embeddings.extend(batch_embeddings)
|
| 572 |
+
break
|
| 573 |
+
|
| 574 |
+
except Exception as e:
|
| 575 |
+
if attempt < max_retries - 1:
|
| 576 |
+
wait_time = 2 ** attempt
|
| 577 |
+
print(f" API error, retrying in {wait_time}s: {e}")
|
| 578 |
+
time.sleep(wait_time)
|
| 579 |
+
else:
|
| 580 |
+
raise RuntimeError(f"Voyage AI API error after {max_retries} retries: {e}")
|
| 581 |
+
|
| 582 |
+
# Small delay to avoid rate limits
|
| 583 |
+
if i + batch_size < len(texts):
|
| 584 |
+
time.sleep(0.1)
|
| 585 |
+
|
| 586 |
+
embeddings = np.array(all_embeddings, dtype=np.float32)
|
| 587 |
+
|
| 588 |
+
# Normalize if requested
|
| 589 |
+
if normalize:
|
| 590 |
+
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
| 591 |
+
embeddings = embeddings / np.maximum(norms, 1e-10)
|
| 592 |
+
|
| 593 |
+
return embeddings
|
| 594 |
+
|
| 595 |
+
@property
|
| 596 |
+
def name(self) -> str:
|
| 597 |
+
"""Get display name for the model."""
        return self.config.get("name", self.model_id)

    @property
    def description(self) -> str:
        """Get description for the model."""
        return self.config.get("description", "")


class GeminiEmbeddingModel(BaseEmbeddingModel):
    """
    Wrapper for Google Gemini embedding API with consistent interface.
    """

    def __init__(
        self,
        model_id: str,
        api_key: Optional[str] = None,
    ):
        """
        Initialize the Gemini embedding model.

        Args:
            model_id: Model ID in format 'gemini/model-name'
            api_key: Gemini API key (optional - can use GEMINI_API_KEY env var
                or Google Cloud Application Default Credentials)
        """
        try:
            from google import genai
        except ImportError:
            raise ImportError(
                "Google GenAI package not installed. Install with: pip install google-genai"
            )

        self.model_id = model_id

        # Get API key from parameter or environment (optional - ADC also works)
        api_key = api_key or os.environ.get("GEMINI_API_KEY")

        # Create client - if no API key, will use Application Default Credentials
        if api_key:
            self.client = genai.Client(api_key=api_key)
        else:
            # Use Application Default Credentials (gcloud auth application-default login)
            self.client = genai.Client()

        # Get model config
        self.config = API_MODELS.get(model_id, {
            "name": model_id,
            "description": "Gemini embedding model",
            "type": "gemini",
            "model_name": model_id.replace("gemini/", "").split("-768")[0].split("-1536")[0],
            "dimensions": 3072,  # Default dimension
        })

        self._model_name = self.config["model_name"]
        self.embedding_dim = self.config["dimensions"]

        print(f"Initialized Gemini embedding model: {self._model_name}")
        print(f"Embedding dimension: {self.embedding_dim}")

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 20,  # Smaller batches to avoid rate limits
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode texts to embeddings using Gemini API.

        Args:
            texts: List of texts to encode
            is_query: Whether these are queries (uses RETRIEVAL_QUERY vs RETRIEVAL_DOCUMENT)
            batch_size: Batch size for API calls (smaller for Gemini to avoid rate limits)
            show_progress: Whether to show progress bar
            normalize: Whether to L2-normalize embeddings

        Returns:
            numpy array of shape (len(texts), embedding_dim)
        """
        import time
        import random
        from google.genai import types

        all_embeddings = []
        total_batches = (len(texts) + batch_size - 1) // batch_size

        # Gemini supports task_type for asymmetric embeddings
        task_type = "RETRIEVAL_QUERY" if is_query else "RETRIEVAL_DOCUMENT"

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_num = i // batch_size + 1

            if show_progress:
                print(f"  Encoding batch {batch_num}/{total_batches}...")

            # Retry logic with exponential backoff for rate limits
            max_retries = 8
            base_delay = 2.0

            for attempt in range(max_retries):
                try:
                    # Build config with task type and output dimensionality
                    embed_config = types.EmbedContentConfig(
                        task_type=task_type,
                        output_dimensionality=self.embedding_dim,
                    )

                    result = self.client.models.embed_content(
                        model=self._model_name,
                        contents=batch,
                        config=embed_config,
                    )

                    # Extract embeddings from response
                    batch_embeddings = [e.values for e in result.embeddings]
                    all_embeddings.extend(batch_embeddings)
                    break

                except Exception as e:
                    error_str = str(e)
                    is_rate_limit = "429" in error_str or "RESOURCE_EXHAUSTED" in error_str

                    if attempt < max_retries - 1:
                        # Exponential backoff with jitter
                        # Longer waits for rate limit errors
                        if is_rate_limit:
                            wait_time = base_delay * (2 ** attempt) + random.uniform(1, 5)
                            print(f"  Rate limited, waiting {wait_time:.1f}s before retry {attempt + 2}/{max_retries}...")
                        else:
                            wait_time = base_delay * (2 ** attempt) + random.uniform(0, 1)
                            print(f"  API error, retrying in {wait_time:.1f}s: {e}")
                        time.sleep(wait_time)
                    else:
                        raise RuntimeError(f"Gemini API error after {max_retries} retries: {e}")

            # Delay between batches to avoid rate limits (longer for Gemini)
            if i + batch_size < len(texts):
                time.sleep(0.5)

        embeddings = np.array(all_embeddings, dtype=np.float32)

        # Normalize if requested (Gemini's 3072d is normalized, but smaller dims need it)
        if normalize:
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
            embeddings = embeddings / np.maximum(norms, 1e-10)

        return embeddings

    @property
    def name(self) -> str:
        """Get display name for the model."""
        return self.config.get("name", self.model_id)

    @property
    def description(self) -> str:
        """Get description for the model."""
        return self.config.get("description", "")

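The retry loop in `encode` follows a standard exponential-backoff-with-jitter pattern. A minimal standalone sketch of the same idea (the `with_backoff` helper and `flaky_call` stand-in are illustrative, not part of this module):

```python
import random
import time

def with_backoff(fn, max_retries=8, base_delay=2.0, sleep=time.sleep):
    """Call fn(), retrying with exponential backoff plus jitter on failure."""
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception as e:
            if attempt == max_retries - 1:
                raise RuntimeError(f"failed after {max_retries} retries: {e}")
            # Delay doubles each attempt (2s, 4s, 8s, ...) plus up to 1s of jitter
            sleep(base_delay * (2 ** attempt) + random.uniform(0, 1))

# Example: a call that fails twice, then succeeds on the third attempt
attempts = {"n": 0}
def flaky_call():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ConnectionError("transient error")
    return "ok"

result = with_backoff(flaky_call, sleep=lambda s: None)  # skip real sleeps in the demo
```

The `encode` method above additionally lengthens the jitter window when it detects a 429/RESOURCE_EXHAUSTED rate-limit error.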
def get_curated_model_choices() -> list[tuple[str, str]]:
    """
    Get list of curated local models for UI dropdown.

    Returns:
        List of (model_id, display_name) tuples
    """
    return [
        (model_id, f"{info['name']} - {info['description']}")
        for model_id, info in CURATED_MODELS.items()
    ]


def get_api_model_choices() -> list[tuple[str, str]]:
    """
    Get list of API-based models for UI dropdown.

    Returns:
        List of (model_id, display_name) tuples
    """
    return [
        (model_id, f"{info['name']} - {info['description']}")
        for model_id, info in API_MODELS.items()
    ]


def get_all_model_choices() -> list[tuple[str, str]]:
    """
    Get list of all models (local + API) for UI dropdown.

    Returns:
        List of (model_id, display_name) tuples
    """
    return get_curated_model_choices() + get_api_model_choices()


def is_api_model(model_id: str) -> bool:
    """Check if a model ID is an API-based model."""
    model_id = model_id.strip()

    # Check if it's in API_MODELS
    if model_id in API_MODELS:
        return True

    # Check if it starts with known API prefixes
    if model_id.startswith("openai/"):
        return True
    if model_id.startswith("voyage/"):
        return True
    if model_id.startswith("gemini/"):
        return True

    return False


def load_model(
    model_id: str,
    device: Optional[str] = None,
    api_key: Optional[str] = None,
) -> BaseEmbeddingModel:
    """
    Load an embedding model by ID.

    Args:
        model_id: Model ID (HuggingFace model ID or API model like 'openai/text-embedding-3-large')
        device: Device to use (for local models only)
        api_key: API key (for API-based models, or uses environment variable)

    Returns:
        Loaded embedding model instance
    """
    model_id = model_id.strip()

    # Check if this is an API model
    if is_api_model(model_id):
        # Check model type from config or prefix
        model_config = API_MODELS.get(model_id, {})
        model_type = model_config.get("type", "")

        if model_type == "voyage" or model_id.startswith("voyage/"):
            return VoyageEmbeddingModel(model_id, api_key=api_key)
        elif model_type == "gemini" or model_id.startswith("gemini/"):
            return GeminiEmbeddingModel(model_id, api_key=api_key)
        elif model_type == "openai" or model_id.startswith("openai/"):
            return OpenAIEmbeddingModel(model_id, api_key=api_key)
        else:
            raise ValueError(f"Unknown API model type: {model_id}")

    # Otherwise, load as a local sentence-transformer model
    return EmbeddingModel(model_id, device=device)

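The dispatch logic in `load_model` boils down to prefix-based routing. A sketch of that routing as a pure function (`route_model` is illustrative; the real `load_model` also consults the `type` field in `API_MODELS`):

```python
def route_model(model_id: str) -> str:
    """Map a model ID to a backend name, mirroring load_model's prefix dispatch."""
    model_id = model_id.strip()
    for prefix, backend in (("voyage/", "voyage"),
                            ("gemini/", "gemini"),
                            ("openai/", "openai")):
        if model_id.startswith(prefix):
            return backend
    # Anything else is treated as a local sentence-transformers model
    return "local"
```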
def validate_model_id(model_id: str) -> tuple[bool, str]:
    """
    Check if a model ID is valid and loadable.

    Args:
        model_id: The model ID to validate

    Returns:
        Tuple of (is_valid, error_message)
    """
    if not model_id or not model_id.strip():
        return False, "Model ID cannot be empty"

    model_id = model_id.strip()

    # Check if it's a curated local model
    if model_id in CURATED_MODELS:
        return True, ""

    # Check if it's a known API model
    if model_id in API_MODELS:
        return True, ""

    # Check for OpenAI models
    if model_id.startswith("openai/"):
        return True, ""

    # Check for Voyage AI models
    if model_id.startswith("voyage/"):
        return True, ""

    # Check for Gemini models
    if model_id.startswith("gemini/"):
        return True, ""

    # For custom models, check if it looks like a valid HF model ID
    if "/" not in model_id:
        return False, "Model ID should be in format 'organization/model-name'"

    # Could add an API check here, but that would slow down validation
    return True, ""


def requires_api_key(model_id: str) -> bool:
    """Check if a model requires an API key."""
    return is_api_model(model_id)


def api_key_optional(model_id: str) -> bool:
    """
    Check if an API key is optional for this model.

    Some providers (like Google Gemini) support Application Default Credentials
    as an alternative to explicit API keys.
    """
    key_type = get_api_key_type(model_id)
    # Gemini supports ADC (gcloud auth application-default login)
    return key_type == "gemini"


def get_api_key_type(model_id: str) -> Optional[str]:
    """
    Get the type of API key required for a model.

    Args:
        model_id: The model ID

    Returns:
        'openai', 'voyage', 'gemini', or None if no API key needed
    """
    if not is_api_model(model_id):
        return None

    model_id = model_id.strip()
    model_config = API_MODELS.get(model_id, {})
    model_type = model_config.get("type", "")

    if model_type == "voyage" or model_id.startswith("voyage/"):
        return "voyage"
    elif model_type == "gemini" or model_id.startswith("gemini/"):
        return "gemini"
    elif model_type == "openai" or model_id.startswith("openai/"):
        return "openai"

    return None


def get_api_key_env_var(model_id: str) -> Optional[str]:
    """
    Get the environment variable name for the API key required by a model.

    Args:
        model_id: The model ID

    Returns:
        Environment variable name or None
    """
    key_type = get_api_key_type(model_id)
    if key_type == "openai":
        return "OPENAI_API_KEY"
    elif key_type == "voyage":
        return "VOYAGE_API_KEY"
    elif key_type == "gemini":
        return "GEMINI_API_KEY"
    return None

if __name__ == "__main__":
|
| 960 |
+
import argparse
|
| 961 |
+
|
| 962 |
+
parser = argparse.ArgumentParser(
|
| 963 |
+
description="Test embedding model loading and encoding"
|
| 964 |
+
)
|
| 965 |
+
parser.add_argument(
|
| 966 |
+
"--local",
|
| 967 |
+
action="store_true",
|
| 968 |
+
help="Test only local sentence-transformer models",
|
| 969 |
+
)
|
| 970 |
+
parser.add_argument(
|
| 971 |
+
"--remote",
|
| 972 |
+
action="store_true",
|
| 973 |
+
help="Test only remote/API models (requires API keys)",
|
| 974 |
+
)
|
| 975 |
+
parser.add_argument(
|
| 976 |
+
"--model",
|
| 977 |
+
type=str,
|
| 978 |
+
default=None,
|
| 979 |
+
help="Test a specific model ID",
|
| 980 |
+
)
|
| 981 |
+
|
| 982 |
+
args = parser.parse_args()
|
| 983 |
+
|
| 984 |
+
# If neither flag specified, test both
|
| 985 |
+
test_local = args.local or (not args.local and not args.remote)
|
| 986 |
+
test_remote = args.remote or (not args.local and not args.remote)
|
| 987 |
+
|
| 988 |
+
print("Testing model loading...")
|
| 989 |
+
|
| 990 |
+
print(f"\nLocal models available:")
|
| 991 |
+
for model_id, display in get_curated_model_choices():
|
| 992 |
+
print(f" - {display}")
|
| 993 |
+
|
| 994 |
+
print(f"\nAPI models available:")
|
| 995 |
+
for model_id, display in get_api_model_choices():
|
| 996 |
+
print(f" - {display}")
|
| 997 |
+
|
| 998 |
+
# Test texts
|
| 999 |
+
test_texts = [
|
| 1000 |
+
"בראשית ברא אלהים את השמים ואת הארץ",
|
| 1001 |
+
"In the beginning God created the heaven and the earth",
|
| 1002 |
+
]
|
| 1003 |
+
|
| 1004 |
+
def run_model_test(model_id: str, model_type: str):
|
| 1005 |
+
"""Run a test for a specific model."""
|
| 1006 |
+
print(f"\n{'='*60}")
|
| 1007 |
+
print(f"Testing {model_type}: {model_id}")
|
| 1008 |
+
print("="*60)
|
| 1009 |
+
|
| 1010 |
+
try:
|
| 1011 |
+
model = load_model(model_id)
|
| 1012 |
+
|
| 1013 |
+
embeddings = model.encode(test_texts, show_progress=False)
|
| 1014 |
+
print(f"\nEncoded {len(test_texts)} texts")
|
| 1015 |
+
print(f"Embedding shape: {embeddings.shape}")
|
| 1016 |
+
|
| 1017 |
+
similarity = np.dot(embeddings[0], embeddings[1])
|
| 1018 |
+
print(f"Cosine similarity between Hebrew and English: {similarity:.4f}")
|
| 1019 |
+
return True
|
| 1020 |
+
except Exception as e:
|
| 1021 |
+
print(f"Test failed: {e}")
|
| 1022 |
+
return False
|
| 1023 |
+
|
| 1024 |
+
# Test specific model if provided
|
| 1025 |
+
if args.model:
|
| 1026 |
+
run_model_test(args.model, "specified model")
|
| 1027 |
+
else:
|
| 1028 |
+
# Test local model
|
| 1029 |
+
if test_local:
|
| 1030 |
+
run_model_test(
|
| 1031 |
+
"sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
|
| 1032 |
+
"local sentence-transformer model"
|
| 1033 |
+
)
|
| 1034 |
+
|
| 1035 |
+
# Test API models
|
| 1036 |
+
if test_remote:
|
| 1037 |
+
# Test OpenAI model
|
| 1038 |
+
if os.environ.get("OPENAI_API_KEY"):
|
| 1039 |
+
run_model_test(
|
| 1040 |
+
"openai/text-embedding-3-small",
|
| 1041 |
+
"OpenAI API model"
|
| 1042 |
+
)
|
| 1043 |
+
else:
|
| 1044 |
+
print("\n(Skipping OpenAI test - OPENAI_API_KEY not set)")
|
| 1045 |
+
|
| 1046 |
+
# Test Voyage AI model
|
| 1047 |
+
if os.environ.get("VOYAGE_API_KEY"):
|
| 1048 |
+
run_model_test(
|
| 1049 |
+
"voyage/voyage-3.5",
|
| 1050 |
+
"Voyage AI API model"
|
| 1051 |
+
)
|
| 1052 |
+
else:
|
| 1053 |
+
print("\n(Skipping Voyage AI test - VOYAGE_API_KEY not set)")
|
| 1054 |
+
|
| 1055 |
+
# Test Gemini model
|
| 1056 |
+
if os.environ.get("GEMINI_API_KEY"):
|
| 1057 |
+
run_model_test(
|
| 1058 |
+
"gemini/gemini-embedding-001",
|
| 1059 |
+
"Gemini API model"
|
| 1060 |
+
)
|
| 1061 |
+
else:
|
| 1062 |
+
print("\n(Skipping Gemini test - GEMINI_API_KEY not set)")
|
| 1063 |
+
|
remove_oversize_entries.py
ADDED
@@ -0,0 +1,51 @@
"""
One-time script to remove entries exceeding the OpenAI embedding token limit
from the benchmark dataset.
"""

import json

# The refs to remove (from the token limit check report)
REFS_TO_REMOVE = [
    "Shemot Rabbah.1:1",
    "Bamidbar Rabbah.1:2",
    "Bamidbar Rabbah.2:10",
    "Shir HaShirim Rabbah.1.1:10",
    "Eichah Rabbah.1:4",
    "Eichah Rabbah.1:23",
    "Eichah Rabbah.1:31",
    "Ramban on Genesis.18:1",
    "Ramban on Genesis.24:2",
    "Ramban on Leviticus.1.9:1",
    "Ramban on Numbers.16:1",
    "Ramban on Numbers.24:1",
    "Ramban on Deuteronomy.2.23:1",
]

def main():
    data_path = "benchmark_data/benchmark.json"

    # Load the data
    print(f"Loading data from: {data_path}")
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    original_count = len(data)
    print(f"Original entry count: {original_count}")

    # Filter out the flagged entries
    filtered_data = [entry for entry in data if entry["ref"] not in REFS_TO_REMOVE]

    removed_count = original_count - len(filtered_data)
    print(f"Removed {removed_count} entries")
    print(f"New entry count: {len(filtered_data)}")

    # Save the filtered data
    print(f"Saving filtered data to: {data_path}")
    with open(data_path, "w", encoding="utf-8") as f:
        json.dump(filtered_data, f, ensure_ascii=False, indent=2)

    print("Done!")

if __name__ == "__main__":
    main()
requirements.txt
ADDED
@@ -0,0 +1,20 @@
# Hugging Face Space dependencies
gradio>=4.0.0
transformers>=4.36.0
sentence-transformers>=2.2.2
torch>=2.0.0

# Data processing
numpy>=1.24.0
pandas>=2.0.0
requests>=2.31.0

# Visualization
plotly>=5.18.0

# API-based embedding providers
openai>=1.0.0
tiktoken>=0.5.0
voyageai>=0.3.0
google-genai>=1.0.0
space_README.md
ADDED
@@ -0,0 +1,48 @@
---
title: Rabbinic Embedding Benchmark
emoji: 📚
colorFrom: blue
colorTo: purple
sdk: gradio
sdk_version: 4.44.0
app_file: app.py
pinned: false
license: mit
---

# Rabbinic Hebrew/Aramaic Embedding Benchmark

Evaluate embedding models on cross-lingual retrieval between Hebrew/Aramaic source texts and their English translations from Sefaria.

## How It Works

Given a Hebrew/Aramaic text, can the model find its correct English translation in a pool of candidates? Models that excel at this task produce high-quality embeddings for Rabbinic literature.

## Metrics

| Metric | Description |
|--------|-------------|
| **MRR** | Mean Reciprocal Rank (average of 1/rank of the correct answer) |
| **Recall@k** | % of queries where the correct translation is in the top k results |
| **Bitext Accuracy** | Accuracy at distinguishing true translation pairs from random pairings |

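As a concrete illustration (a sketch, not the Space's actual evaluation code), MRR and Recall@k can be computed from a query-by-candidate similarity matrix where the correct candidate for query `i` sits at index `i`:

```python
import numpy as np

def mrr_and_recall_at_k(sims: np.ndarray, k: int = 5) -> tuple[float, float]:
    """sims[i, j] = similarity of query i to candidate j; the correct
    candidate for query i is assumed to be at column i (paired data)."""
    # Candidate indices sorted by descending similarity, per query
    order = np.argsort(-sims, axis=1)
    # Rank of the correct candidate for each query (1 = best)
    ranks = np.argmax(order == np.arange(len(sims))[:, None], axis=1) + 1
    mrr = float(np.mean(1.0 / ranks))
    recall_k = float(np.mean(ranks <= k))
    return mrr, recall_k

# Toy 3-query example: queries 0 and 2 rank their match first, query 1 second
sims = np.array([
    [0.9, 0.2, 0.1],
    [0.8, 0.6, 0.3],
    [0.1, 0.2, 0.7],
])
mrr, r1 = mrr_and_recall_at_k(sims, k=1)  # MRR = (1 + 1/2 + 1)/3, Recall@1 = 2/3
```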
## Corpus

The benchmark includes diverse texts with English translations:

- **Talmud**: Bavli & Yerushalmi
- **Mishnah**: Selected tractates
- **Midrash**: Midrash Rabbah
- **Commentary**: Rashi, Ramban, Radak, Rabbeinu Behaye
- **Philosophy**: Guide for the Perplexed, Sefer HaIkkarim
- **Hasidic/Kabbalistic**: Likutei Moharan, Tomer Devorah, Kalach Pitchei Chokhmah
- **Mussar**: Chafetz Chaim, Kav HaYashar, Iggeret HaRamban
- **Halacha**: Sefer HaChinukh, Mishneh Torah

All texts are sourced from [Sefaria](https://www.sefaria.org).

## API Keys

For API-based models (OpenAI, Voyage AI, Gemini), you can either:

- Enter your API key in the interface (it is not stored)
- Set environment variables in the Space settings: `OPENAI_API_KEY`, `VOYAGE_API_KEY`, `GEMINI_API_KEY`