#!/usr/bin/env python3
"""
Script to build the benchmark dataset from Sefaria API.

Run this script to fetch and cache parallel Hebrew/Aramaic-English text
pairs from Sefaria for use in the embedding evaluation benchmark.

Usage:
    python build_benchmark.py [--max-per-text N] [--total N] [--output PATH]
"""

import argparse
import sys
from pathlib import Path

import requests

from data_loader import (
    build_benchmark_dataset,
    get_benchmark_stats,
    get_index_from_sefaria,
    set_sefaria_host,
    get_sefaria_host,
    BENCHMARK_TEXTS,
)


def get_name_suggestions(title: str, host: str, limit: int = 5) -> list[str]:
    """Get name suggestions from the Sefaria name API.

    Args:
        title: The (possibly invalid) text title to look up.
        host: Base URL of the Sefaria server, e.g. "https://www.sefaria.org".
        limit: Maximum number of suggestions to return.

    Returns:
        Up to ``limit`` completion strings from the name API, or an empty
        list when the request fails, the status is not 200, or the body
        cannot be parsed.
    """
    try:
        url = f"{host}/api/name/{title}"
        response = requests.get(url, params={"limit": limit, "type": "ref"}, timeout=10)
        if response.status_code == 200:
            data = response.json()
            # Return completions that are refs (book titles)
            completions = data.get("completions", [])
            return completions[:limit]
    # ValueError covers json.JSONDecodeError from response.json() on
    # requests < 2.27, where it is NOT a RequestException subclass; on
    # newer requests JSONDecodeError inherits both, so this is a no-op.
    except (requests.RequestException, ValueError):
        pass
    return []


def main():
    """Parse CLI arguments, then build or inspect the benchmark dataset.

    Three modes, checked in order:
      --check-titles: verify every BENCHMARK_TEXTS title against the API
                      and suggest corrections for invalid ones, then exit.
      --dry-run:      list what would be fetched without any API calls,
                      then exit.
      (default):      fetch the dataset, save it to --output, and print
                      summary statistics.
    """
    parser = argparse.ArgumentParser(
        description="Build Rabbinic embedding benchmark dataset from Sefaria"
    )
    parser.add_argument(
        "--max-per-text",
        type=int,
        default=40,
        help="Maximum segments per text (default: 40)",
    )
    parser.add_argument(
        "--total",
        type=int,
        default=10000,
        help="Total target segments (default: 10000)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="benchmark_data/benchmark.json",
        help="Output file path (default: benchmark_data/benchmark.json)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be fetched without making API calls",
    )
    parser.add_argument(
        "--host",
        type=str,
        default=None,
        help="Sefaria host URL (default: https://www.sefaria.org, or SEFARIA_HOST env var)",
    )
    parser.add_argument(
        "--check-titles",
        action="store_true",
        help="Check all text titles against the API to verify they exist",
    )
    args = parser.parse_args()

    # Configure Sefaria host if specified
    if args.host:
        set_sefaria_host(args.host)

    if args.check_titles:
        print("="*60)
        print("Checking Text Titles Against API")
        print("="*60)
        host = get_sefaria_host()
        print(f"\nSefaria host: {host}\n")
        valid = []
        invalid = []
        suggestions = {}
        for category_key, category_info in BENCHMARK_TEXTS.items():
            category_name = category_info["category"]
            print(f"\n{category_name}:")
            for text in category_info["texts"]:
                index = get_index_from_sefaria(text)
                if index:
                    print(f" ✓ {text}")
                    valid.append(text)
                else:
                    # Get suggestions from name API
                    suggested = get_name_suggestions(text, host)
                    suggestions[text] = suggested
                    if suggested:
                        print(f" ✗ {text} → Did you mean: {suggested[0]}?")
                    else:
                        print(f" ✗ {text}")
                    invalid.append(text)
        print("\n" + "="*60)
        print("SUMMARY")
        print("="*60)
        print(f"\nValid titles: {len(valid)}")
        print(f"Invalid titles: {len(invalid)}")
        if invalid:
            print(f"\nInvalid titles that need fixing:")
            for title in invalid:
                suggested = suggestions.get(title, [])
                if suggested:
                    print(f" - {title}")
                    print(f" Suggestions: {', '.join(suggested[:3])}")
                else:
                    print(f" - {title} (no suggestions found)")
        else:
            print("\nAll titles are valid!")
        return

    if args.dry_run:
        print("DRY RUN: Would fetch from these texts:\n")
        print(f"Sefaria host: {get_sefaria_host()}")
        total_texts = 0
        for category_key, category_info in BENCHMARK_TEXTS.items():
            print(f"\n{category_info['category']} ({category_info['language']}):")
            for text in category_info["texts"]:
                print(f" - {text}")
                total_texts += 1
        print(f"\nTotal texts: {total_texts}")
        print(f"Target segments per text: {args.max_per_text}")
        print(f"Total target segments: {args.total}")
        return

    print("="*60)
    print("Building Rabbinic Embedding Benchmark Dataset")
    print("="*60)
    print(f"\nSettings:")
    print(f" Sefaria host: {get_sefaria_host()}")
    print(f" Max segments per text: {args.max_per_text}")
    print(f" Total target: {args.total}")
    print(f" Output: {args.output}")
    print()

    # Ensure output directory exists
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)

    # Build the dataset
    pairs = build_benchmark_dataset(
        output_path=args.output,
        segments_per_text=args.max_per_text,
        total_target=args.total,
    )

    # Print final statistics
    stats = get_benchmark_stats(pairs)
    print("\n" + "="*60)
    print("BENCHMARK COMPLETE")
    print("="*60)
    print(f"\nFinal Statistics:")
    print(f" Total pairs: {stats['total_pairs']:,}")
    print(f" Categories:")
    for cat, count in sorted(stats["categories"].items()):
        print(f" - {cat}: {count:,}")
    print(f" Average Hebrew text length: {stats['avg_he_length']:.0f} chars")
    print(f" Average English text length: {stats['avg_en_length']:.0f} chars")
    print(f"\nSaved to: {args.output}")


if __name__ == "__main__":
    main()