#!/usr/bin/env python3
"""
Script to build the benchmark dataset from the Sefaria API.
Run this script to fetch and cache parallel Hebrew/Aramaic-English text pairs
from Sefaria for use in the embedding evaluation benchmark.
Usage:
python build_benchmark.py [--max-per-text N] [--total N] [--output PATH] [--host URL] [--dry-run] [--check-titles]
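Examples:
    python build_benchmark.py --dry-run
    python build_benchmark.py --check-titles
    python build_benchmark.py --total 5000 --output benchmark_data/small.json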
"""
import argparse
from pathlib import Path
import requests
from data_loader import (
build_benchmark_dataset,
get_benchmark_stats,
get_index_from_sefaria,
set_sefaria_host,
get_sefaria_host,
BENCHMARK_TEXTS,
)
def get_name_suggestions(title: str, host: str, limit: int = 5) -> list[str]:
"""Get name suggestions from the Sefaria name API."""
try:
url = f"{host}/api/name/{title}"
response = requests.get(url, params={"limit": limit, "type": "ref"}, timeout=10)
if response.status_code == 200:
data = response.json()
        # type=ref asks the API to limit completions to citable refs (book
        # titles); we only truncate the returned list here
completions = data.get("completions", [])
return completions[:limit]
except requests.RequestException:
pass
return []
def main():
parser = argparse.ArgumentParser(
description="Build Rabbinic embedding benchmark dataset from Sefaria"
)
parser.add_argument(
"--max-per-text",
type=int,
default=40,
help="Maximum segments per text (default: 40)",
)
parser.add_argument(
"--total",
type=int,
default=10000,
help="Total target segments (default: 10000)",
)
parser.add_argument(
"--output",
type=str,
default="benchmark_data/benchmark.json",
help="Output file path (default: benchmark_data/benchmark.json)",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Show what would be fetched without making API calls",
)
parser.add_argument(
"--host",
type=str,
default=None,
help="Sefaria host URL (default: https://www.sefaria.org, or SEFARIA_HOST env var)",
)
parser.add_argument(
"--check-titles",
action="store_true",
help="Check all text titles against the API to verify they exist",
)
args = parser.parse_args()
# Configure Sefaria host if specified
if args.host:
set_sefaria_host(args.host)
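    # --check-titles mode: validate each configured title against the Sefaria
    # index API, suggest corrections for any misses, then exit.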
if args.check_titles:
print("="*60)
print("Checking Text Titles Against API")
print("="*60)
host = get_sefaria_host()
print(f"\nSefaria host: {host}\n")
valid = []
invalid = []
suggestions = {}
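        # maps each invalid title to ranked suggestions from the name API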
for category_key, category_info in BENCHMARK_TEXTS.items():
category_name = category_info["category"]
print(f"\n{category_name}:")
for text in category_info["texts"]:
index = get_index_from_sefaria(text)
if index:
print(f" βœ“ {text}")
valid.append(text)
else:
# Get suggestions from name API
suggested = get_name_suggestions(text, host)
suggestions[text] = suggested
if suggested:
print(f" βœ— {text} β†’ Did you mean: {suggested[0]}?")
else:
print(f" βœ— {text}")
invalid.append(text)
print("\n" + "="*60)
print("SUMMARY")
print("="*60)
print(f"\nValid titles: {len(valid)}")
print(f"Invalid titles: {len(invalid)}")
if invalid:
print(f"\nInvalid titles that need fixing:")
for title in invalid:
suggested = suggestions.get(title, [])
if suggested:
print(f" - {title}")
print(f" Suggestions: {', '.join(suggested[:3])}")
else:
print(f" - {title} (no suggestions found)")
else:
print("\nAll titles are valid!")
return
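    # --dry-run mode: list the configured texts and settings without making
    # any API calls.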
if args.dry_run:
print("DRY RUN: Would fetch from these texts:\n")
print(f"Sefaria host: {get_sefaria_host()}")
total_texts = 0
for category_key, category_info in BENCHMARK_TEXTS.items():
print(f"\n{category_info['category']} ({category_info['language']}):")
for text in category_info["texts"]:
print(f" - {text}")
total_texts += 1
print(f"\nTotal texts: {total_texts}")
print(f"Target segments per text: {args.max_per_text}")
print(f"Total target segments: {args.total}")
return
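    # Default mode: fetch parallel Hebrew/Aramaic-English segments from
    # Sefaria and cache them to disk.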
print("="*60)
print("Building Rabbinic Embedding Benchmark Dataset")
print("="*60)
print(f"\nSettings:")
print(f" Sefaria host: {get_sefaria_host()}")
print(f" Max segments per text: {args.max_per_text}")
print(f" Total target: {args.total}")
print(f" Output: {args.output}")
print()
# Ensure output directory exists
Path(args.output).parent.mkdir(parents=True, exist_ok=True)
# Build the dataset
pairs = build_benchmark_dataset(
output_path=args.output,
segments_per_text=args.max_per_text,
total_target=args.total,
)
# Print final statistics
stats = get_benchmark_stats(pairs)
print("\n" + "="*60)
print("BENCHMARK COMPLETE")
print("="*60)
print(f"\nFinal Statistics:")
print(f" Total pairs: {stats['total_pairs']:,}")
print(f" Categories:")
for cat, count in sorted(stats["categories"].items()):
print(f" - {cat}: {count:,}")
print(f" Average Hebrew text length: {stats['avg_he_length']:.0f} chars")
print(f" Average English text length: {stats['avg_en_length']:.0f} chars")
print(f"\nSaved to: {args.output}")
if __name__ == "__main__":
main()