|
|
|
|
|
""" |
|
|
Script to build the benchmark dataset from Sefaria API. |
|
|
|
|
|
Run this script to fetch and cache parallel Hebrew/Aramaic-English text pairs |
|
|
from Sefaria for use in the embedding evaluation benchmark. |
|
|
|
|
|
Usage:
    python build_benchmark.py [--max-per-text N] [--total N] [--output PATH]
                              [--host URL] [--dry-run] [--check-titles]
|
|
""" |
|
|
|
|
|
import argparse
import sys
from pathlib import Path
from urllib.parse import quote

import requests

from data_loader import (
    build_benchmark_dataset,
    get_benchmark_stats,
    get_index_from_sefaria,
    set_sefaria_host,
    get_sefaria_host,
    BENCHMARK_TEXTS,
)
|
|
|
|
|
|
|
|
def get_name_suggestions(title: str, host: str, limit: int = 5) -> list[str]:
    """Query the Sefaria name API for ref-style completions of *title*.

    Best-effort lookup used to suggest corrections for invalid titles.

    Args:
        title: Candidate text title to look up.
        host: Base URL of the Sefaria instance (e.g. "https://www.sefaria.org").
        limit: Maximum number of suggestions to return.

    Returns:
        Up to ``limit`` completion strings, or an empty list on any
        network, HTTP, or JSON-parse failure.
    """
    try:
        # Percent-encode the title: Sefaria titles routinely contain spaces
        # and punctuation that are not safe in a URL path segment.
        url = f"{host}/api/name/{quote(title)}"
        response = requests.get(url, params={"limit": limit, "type": "ref"}, timeout=10)
        if response.status_code == 200:
            data = response.json()
            completions = data.get("completions", [])
            return completions[:limit]
    except (requests.RequestException, ValueError):
        # ValueError covers malformed JSON bodies (json.JSONDecodeError),
        # which would otherwise escape the requests-level handler.
        pass
    return []
|
|
|
|
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Build the CLI parser and parse sys.argv."""
    parser = argparse.ArgumentParser(
        description="Build Rabbinic embedding benchmark dataset from Sefaria"
    )
    parser.add_argument(
        "--max-per-text",
        type=int,
        default=40,
        help="Maximum segments per text (default: 40)",
    )
    parser.add_argument(
        "--total",
        type=int,
        default=10000,
        help="Total target segments (default: 10000)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="benchmark_data/benchmark.json",
        help="Output file path (default: benchmark_data/benchmark.json)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be fetched without making API calls",
    )
    parser.add_argument(
        "--host",
        type=str,
        default=None,
        help="Sefaria host URL (default: https://www.sefaria.org, or SEFARIA_HOST env var)",
    )
    parser.add_argument(
        "--check-titles",
        action="store_true",
        help="Check all text titles against the API to verify they exist",
    )
    return parser.parse_args()


def _check_titles() -> None:
    """Verify every configured title resolves on the API; print a report."""
    print("="*60)
    print("Checking Text Titles Against API")
    print("="*60)
    host = get_sefaria_host()
    print(f"\nSefaria host: {host}\n")

    valid = []
    invalid = []
    suggestions = {}

    for category_key, category_info in BENCHMARK_TEXTS.items():
        category_name = category_info["category"]
        print(f"\n{category_name}:")

        for text in category_info["texts"]:
            index = get_index_from_sefaria(text)
            if index:
                print(f"  β {text}")
                valid.append(text)
            else:
                # Look up close matches so the report is actionable.
                suggested = get_name_suggestions(text, host)
                suggestions[text] = suggested
                if suggested:
                    print(f"  β {text} β Did you mean: {suggested[0]}?")
                else:
                    print(f"  β {text}")
                invalid.append(text)

    print("\n" + "="*60)
    print("SUMMARY")
    print("="*60)
    print(f"\nValid titles: {len(valid)}")
    print(f"Invalid titles: {len(invalid)}")

    if invalid:
        print("\nInvalid titles that need fixing:")
        for title in invalid:
            suggested = suggestions.get(title, [])
            if suggested:
                print(f"  - {title}")
                print(f"    Suggestions: {', '.join(suggested[:3])}")
            else:
                print(f"  - {title} (no suggestions found)")
    else:
        print("\nAll titles are valid!")


def _dry_run(args: argparse.Namespace) -> None:
    """List every text that would be fetched, without any API calls."""
    print("DRY RUN: Would fetch from these texts:\n")
    print(f"Sefaria host: {get_sefaria_host()}")
    total_texts = 0
    for category_key, category_info in BENCHMARK_TEXTS.items():
        print(f"\n{category_info['category']} ({category_info['language']}):")
        for text in category_info["texts"]:
            print(f"  - {text}")
            total_texts += 1
    print(f"\nTotal texts: {total_texts}")
    print(f"Target segments per text: {args.max_per_text}")
    print(f"Total target segments: {args.total}")


def _build(args: argparse.Namespace) -> None:
    """Fetch the dataset from Sefaria, save it, and print final statistics."""
    print("="*60)
    print("Building Rabbinic Embedding Benchmark Dataset")
    print("="*60)
    print("\nSettings:")
    print(f"  Sefaria host: {get_sefaria_host()}")
    print(f"  Max segments per text: {args.max_per_text}")
    print(f"  Total target: {args.total}")
    print(f"  Output: {args.output}")
    print()

    # Ensure the output directory exists before the (long) fetch starts.
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)

    pairs = build_benchmark_dataset(
        output_path=args.output,
        segments_per_text=args.max_per_text,
        total_target=args.total,
    )

    stats = get_benchmark_stats(pairs)

    print("\n" + "="*60)
    print("BENCHMARK COMPLETE")
    print("="*60)
    print("\nFinal Statistics:")
    print(f"  Total pairs: {stats['total_pairs']:,}")
    print("  Categories:")
    for cat, count in sorted(stats["categories"].items()):
        print(f"    - {cat}: {count:,}")
    print(f"  Average Hebrew text length: {stats['avg_he_length']:.0f} chars")
    print(f"  Average English text length: {stats['avg_en_length']:.0f} chars")
    print(f"\nSaved to: {args.output}")


def main():
    """CLI entry point: parse arguments, then run exactly one mode.

    Modes, checked in order (mutually exclusive):
      --check-titles: verify every configured title against the API.
      --dry-run:      list what would be fetched, no API calls.
      default:        fetch the dataset, save it, and print statistics.
    """
    args = _parse_args()

    # Override the default Sefaria host before any API helper is used.
    if args.host:
        set_sefaria_host(args.host)

    if args.check_titles:
        _check_titles()
        return

    if args.dry_run:
        _dry_run(args)
        return

    _build(args)
|
|
|
|
|
|
|
|
# Allow the module to be imported (e.g. for testing) without side effects;
# only run the CLI when executed directly.
if __name__ == "__main__":
    main()
|
|
|
|
|
|