""" Gradio interface for Rabbinic Hebrew/Aramaic Embedding Evaluation. A Hugging Face Space for evaluating embedding models on cross-lingual retrieval between Hebrew/Aramaic source texts and English translations. """ import os import threading from datetime import datetime import gradio as gr import pandas as pd import plotly.graph_objects as go from data_loader import load_benchmark_dataset, get_benchmark_stats from models import ( API_MODELS, ALL_MODELS, load_model, validate_model_id, requires_api_key, api_key_optional, get_api_key_type, get_api_key_env_var, ) from evaluation import ( EvaluationResults, evaluate_model, evaluate_model_streaming, compute_similarity_matrix, get_rank_distribution, ) from leaderboard import ( load_leaderboard as load_leaderboard_from_hub, add_result as add_result_to_hub, ) from jobs import ( Job, create_job, get_job, update_job_progress, complete_job, fail_job, delete_job, cleanup_old_jobs, cleanup_stale_jobs, ) # HuggingFace Dataset ID for benchmark data BENCHMARK_DATASET_ID = "Sefaria/Rabbinic-Hebrew-English-Pairs" # Global state _benchmark_data = None def load_benchmark(): """Load benchmark data from HuggingFace Hub, with fallback to sample data.""" global _benchmark_data if _benchmark_data is not None: return _benchmark_data try: _benchmark_data = load_benchmark_dataset(BENCHMARK_DATASET_ID) print(f"Loaded {len(_benchmark_data)} benchmark pairs from {BENCHMARK_DATASET_ID}") except Exception as e: print(f"Failed to load benchmark: {e}") print("Using sample data for testing") # Create minimal sample data for testing _benchmark_data = [ { "ref": "Sample.1", "he": "בראשית ברא אלהים את השמים ואת הארץ", "en": "In the beginning God created the heaven and the earth", "category": "Sample", }, { "ref": "Sample.2", "he": "והארץ היתה תהו ובהו וחשך על פני תהום", "en": "And the earth was without form, and void; and darkness was upon the face of the deep", "category": "Sample", }, ] return _benchmark_data def load_leaderboard(): """Load leaderboard from HuggingFace Hub.""" return load_leaderboard_from_hub() def add_to_leaderboard(results: EvaluationResults): """Add evaluation results to leaderboard on HuggingFace Hub.""" entry = results.to_dict() entry["timestamp"] = datetime.now().isoformat() # Add to Hub (handles deduplication and sorting internally) success = add_result_to_hub(entry) if not success: print("Note: Results saved locally but not persisted to Hub (no HF_TOKEN)") def format_leaderboard_df(): """Format leaderboard as pandas DataFrame for display.""" leaderboard = load_leaderboard() if not leaderboard: return pd.DataFrame(columns=[ "#", "Model", "MRR", "R@1", "R@5", "R@10", "Bitext", "TrueSim", "RandSim", "N" ]) rows = [] for i, entry in enumerate(leaderboard, 1): rows.append({ "#": i, "Model": entry.get("model_name", entry["model_id"]), "MRR": f"{entry['mrr']:.3f}", "R@1": f"{entry['recall_at_1']:.1%}", "R@5": f"{entry['recall_at_5']:.1%}", "R@10": f"{entry['recall_at_10']:.1%}", "Bitext": f"{entry['bitext_accuracy']:.1%}", "TrueSim": f"{entry['avg_true_pair_similarity']:.3f}", "RandSim": f"{entry['avg_random_pair_similarity']:.3f}", "N": entry["num_pairs"], }) return pd.DataFrame(rows) def run_evaluation_background(job_id: str, model_id: str, api_key: str, max_pairs: int): """ Run evaluation in background thread, writing progress to job file. This function runs in a daemon thread. It writes progress updates to a persistent job file so that the UI can poll for status even if the original HTTP connection times out. 
""" try: update_job_progress(job_id, "⏳ Loading benchmark data...", 0.0) benchmark = load_benchmark() if max_pairs and max_pairs < len(benchmark): benchmark = benchmark[:max_pairs] update_job_progress(job_id, f"⏳ Loading model: {model_id}...", 0.05) model = load_model(model_id, api_key=api_key if api_key else None) # Progress callback for evaluation - writes to job file def progress_callback(progress_frac: float, msg: str): # Scale progress: model loading is 0-10%, evaluation is 10-95%, saving is 95-100% scaled_progress = 0.10 + (progress_frac * 0.85) update_job_progress(job_id, msg, scaled_progress) update_job_progress(job_id, "⏳ Starting evaluation...", 0.10) results = evaluate_model( model, benchmark, batch_size=32, progress_callback=progress_callback, ) update_job_progress(job_id, "⏳ Saving results to leaderboard...", 0.95) add_to_leaderboard(results) # Format results summary summary = f"""## Results for {results.model_name} | Metric | Value | |--------|-------| | **MRR** | {results.mrr:.4f} | | **Recall@1** | {results.recall_at_1:.1%} | | **Recall@5** | {results.recall_at_5:.1%} | | **Recall@10** | {results.recall_at_10:.1%} | | **Bitext Accuracy** | {results.bitext_accuracy:.1%} | | **Avg True Pair Sim** | {results.avg_true_pair_similarity:.4f} | | **Avg Random Pair Sim** | {results.avg_random_pair_similarity:.4f} | | **Pairs Evaluated** | {results.num_pairs:,} | """ complete_job(job_id, summary) except Exception as e: fail_job(job_id, str(e)) def start_evaluation( model_choice: str, custom_model_id: str, api_key: str, max_pairs: int, ): """ Start evaluation in background and return job ID. Creates a persistent job file and starts a background thread. The job file allows the UI to poll for status even if the HTTP connection times out. """ # Determine which model to use if model_choice == "custom": model_id = custom_model_id.strip() is_valid, error = validate_model_id(model_id) if not is_valid: return ( "", # job_id f"❌ {error}", f"❌ Invalid model ID: {error}", format_leaderboard_df(), gr.update(visible=False), # check_status_btn gr.update(visible=False, value=""), # job_id_display ) else: model_id = model_choice # Check if API key is required but not provided if requires_api_key(model_id): api_key = api_key.strip() if api_key else "" env_var = get_api_key_env_var(model_id) key_type = get_api_key_type(model_id) if not api_key and not os.environ.get(env_var) and not api_key_optional(model_id): return ( "", "❌ API key required", f"❌ API key required for {model_id}. Please enter your {key_type.upper()} API key or set the {env_var} environment variable.", format_leaderboard_df(), gr.update(visible=False), gr.update(visible=False, value=""), ) # Get model display name model_name = model_id if model_id in ALL_MODELS: model_name = ALL_MODELS[model_id].get("name", model_id) # Create persistent job job = create_job(model_id=model_id, model_name=model_name, max_pairs=max_pairs) # Start background thread thread = threading.Thread( target=run_evaluation_background, args=(job.job_id, model_id, api_key, max_pairs), daemon=True, ) thread.start() return ( job.job_id, "⏳ Evaluation started! Click 'Check Status' to see progress (auto-refreshes every 5 seconds).", "", gr.update(), # Don't update leaderboard yet gr.update(visible=True), # Show check_status_btn gr.update(visible=True, value=f"Job ID: {job.job_id[:8]}..."), # Show job_id_display ) def check_job_status(job_id: str): """ Check job status by reading the job file. This is a stateless operation - each check reads fresh data from disk. 


def check_job_status(job_id: str):
    """
    Check job status by reading the job file.

    This is a stateless operation - each check reads fresh data from disk.
    Uses regular HTTP POST (not SSE) so it survives HF Spaces proxy timeouts.
    """
    if not job_id:
        return (
            "",
            "",
            gr.update(),
            gr.update(visible=False),  # Hide check button
            gr.update(visible=False, value=""),  # Hide job ID
        )

    job = get_job(job_id)
    if job is None:
        # Job not found - might have been cleaned up or never existed
        return (
            "⚠️ Job not found. It may have expired or been cleaned up.",
            "",
            gr.update(),
            gr.update(visible=False),
            gr.update(visible=False, value=""),
        )

    if job.status == "completed":
        # Job completed successfully.
        # Clean up the job file after retrieving results.
        delete_job(job_id)
        return (
            job.progress,
            job.result or "",
            format_leaderboard_df(),
            gr.update(visible=False),  # Hide check button
            gr.update(visible=False, value=""),  # Hide job ID
        )
    elif job.status == "failed":
        # Job failed
        error_msg = job.error or "Unknown error"
        # Clean up the job file
        delete_job(job_id)
        return (
            job.progress,
            f"❌ Error: {error_msg}",
            format_leaderboard_df(),
            gr.update(visible=False),  # Hide check button
            gr.update(visible=False, value=""),  # Hide job ID
        )
    else:
        # Still running (status is "pending" or "running").
        # Include the progress percentage in the status message.
        pct = int(job.progress_pct * 100)
        progress_with_pct = f"{job.progress} ({pct}%)"
        return (
            progress_with_pct,
            "",
            gr.update(),  # Don't update the leaderboard yet
            gr.update(visible=True),  # Keep check button visible
            gr.update(visible=True),  # Keep job ID visible
        )


def create_leaderboard_comparison():
    """Create a comparison chart of all models on the leaderboard."""
    leaderboard = load_leaderboard()
    if len(leaderboard) < 2:
        return None

    models = [e.get("model_name", e["model_id"]) for e in leaderboard]
    mrr = [e["mrr"] for e in leaderboard]
    r1 = [e["recall_at_1"] for e in leaderboard]
    r5 = [e["recall_at_5"] for e in leaderboard]
    r10 = [e["recall_at_10"] for e in leaderboard]
    bitext = [e["bitext_accuracy"] for e in leaderboard]

    fig = go.Figure()
    fig.add_trace(go.Bar(name="MRR", x=models, y=mrr, marker_color="#2E86AB"))
    fig.add_trace(go.Bar(name="R@1", x=models, y=r1, marker_color="#A23B72"))
    fig.add_trace(go.Bar(name="R@5", x=models, y=r5, marker_color="#F18F01"))
    fig.add_trace(go.Bar(name="R@10", x=models, y=r10, marker_color="#C73E1D"))
    fig.add_trace(go.Bar(name="Bitext Acc", x=models, y=bitext, marker_color="#6B5B95"))
    fig.update_layout(
        title="Model Comparison",
        yaxis_title="Score",
        yaxis_range=[0, 1],
        barmode="group",
        template="plotly_white",
        height=400,
    )
    return fig


def update_model_inputs_visibility(choice):
    """Show/hide custom model input and API key field based on selection."""
    show_custom = (choice == "custom")
    show_api_key = requires_api_key(choice) if choice != "custom" else False

    # Update the API key label based on the model type
    if show_api_key:
        key_type = get_api_key_type(choice)
        env_var = get_api_key_env_var(choice)
        is_optional = api_key_optional(choice)
        if key_type == "voyage":
            label = "Voyage AI API Key"
            placeholder = f"Enter your Voyage AI API key (or set {env_var} env var)"
        elif key_type == "gemini":
            label = "Gemini API Key (optional if using gcloud)"
            placeholder = f"Leave blank if using gcloud ADC, or enter API key / set {env_var}"
        elif key_type == "cohere":
            label = "Cohere API Key"
            placeholder = f"Enter your Cohere API key (or set {env_var} env var)"
        else:
            label = "OpenAI API Key"
            placeholder = f"Enter your OpenAI API key (or set {env_var} env var)"
        return (
            gr.update(visible=show_custom),
            gr.update(visible=show_api_key, label=label, placeholder=placeholder),
        )
    return (
        gr.update(visible=show_custom),
        gr.update(visible=show_api_key),
    )
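

# Debugging aid (a sketch, not wired into the UI): the same jobs API can be
# polled from a Python shell, mirroring what the Check Status button does.
# `job_id` is whatever start_evaluation() returned; requires Python 3.8+ for
# the walrus operator.
#
#   import time
#   from jobs import get_job
#
#   while (job := get_job(job_id)) and job.status in ("pending", "running"):
#       print(f"{job.progress} ({job.progress_pct:.0%})")
#       time.sleep(5)
#   if job:
#       print(job.result or job.error)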


# Build the Gradio interface
def create_app():
    """Create and return the Gradio app."""
    # Clean up any stale jobs from previous runs
    # (e.g., if the Space restarted mid-evaluation)
    print("Cleaning up stale jobs...")
    cleanup_stale_jobs(stale_minutes=30)
    cleanup_old_jobs(max_age_hours=24)

    # Get all model choices - custom first, then API models
    model_choices = []
    # Custom option first
    model_choices.append(("⚙️ Custom Model (enter HuggingFace ID below)", "custom"))
    # API models
    for model_id, info in API_MODELS.items():
        model_choices.append((f"🌐 {info['name']}", model_id))

    # Load initial data
    load_benchmark()
    load_leaderboard()
    benchmark_stats = get_benchmark_stats(_benchmark_data) if _benchmark_data else {}

    # Precompute the display string so a missing stat renders as "N/A" instead
    # of raising (a ":," format spec cannot be applied to the string "N/A")
    total_pairs_str = (
        f"{benchmark_stats['total_pairs']:,}"
        if "total_pairs" in benchmark_stats else "N/A"
    )

    with gr.Blocks(
        title="Rabbinic Embedding Benchmark",
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="orange",
            font=gr.themes.GoogleFont("Source Sans Pro"),
        ),
        css="""
        .main-header { text-align: center; margin-bottom: 1rem; }
        .stats-box {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            padding: 1rem;
            border-radius: 8px;
            margin: 0.5rem 0;
        }
        """,
    ) as app:
        # Hidden state for job tracking (persists the job ID across poll requests)
        job_id_state = gr.State("")

        gr.Markdown(
            """
            # 📚 Rabbinic Hebrew/Aramaic Embedding Benchmark

            Evaluate embedding models on cross-lingual retrieval between
            Hebrew/Aramaic source texts and their English translations from Sefaria.

            **How it works:** Given a Hebrew/Aramaic text, can the model find its
            correct English translation from a pool of candidates? Models that excel
            at this task produce high-quality embeddings for Rabbinic literature.
            """,
            elem_classes=["main-header"],
        )

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(f"""
                ### 📊 Benchmark Stats
                - **Total Pairs:** {total_pairs_str}
                - **Categories:** {len(benchmark_stats.get('categories', {}))}
                - **Avg Hebrew Length:** {benchmark_stats.get('avg_he_length', 0):.0f} chars
                - **Dataset:** [View on Hugging Face](https://huggingface.co/datasets/{BENCHMARK_DATASET_ID})
                """)
            with gr.Column(scale=1):
                gr.Markdown("""
                ### 📏 Metrics
                - **MRR:** Mean Reciprocal Rank
                - **R@k:** Recall at k (correct in top k)
                - **Bitext Acc:** True vs random pair classification
                """)

        gr.Markdown("---")

        with gr.Tabs(selected=0):  # Default to the Leaderboard tab
            with gr.TabItem("🏆 Leaderboard"):
                leaderboard_table = gr.Dataframe(
                    value=format_leaderboard_df(),
                    label="Model Rankings",
                    interactive=False,
                )
                refresh_btn = gr.Button("🔄 Refresh Leaderboard")
                comparison_plot = gr.Plot(
                    value=create_leaderboard_comparison(),
                    label="Model Comparison",
                )

            with gr.TabItem("🔬 Evaluate Model"):
                with gr.Row():
                    with gr.Column(scale=2):
                        model_dropdown = gr.Dropdown(
                            choices=model_choices,
                            value="custom",
                            label="Select Model",
                            info="Enter a HuggingFace model ID or choose an API model",
                        )
                        custom_model_input = gr.Textbox(
                            label="HuggingFace Model ID",
                            placeholder="e.g., intfloat/multilingual-e5-large",
                            visible=True,  # Visible by default since "custom" is selected
                        )
                        api_key_input = gr.Textbox(
                            label="API Key",
                            placeholder="Enter your API key (or set appropriate env var)",
                            type="password",
                            visible=False,
                            info="Required for API-based models (OpenAI, Voyage AI). "
                                 "Your key is not stored.",
                        )
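                        # Key precedence (an inference from start_evaluation()
                        # and run_evaluation_background(), not a contract from
                        # models.py): a key typed here is passed straight to
                        # load_model(); if left blank, the model's env var is
                        # expected to cover it, and for Gemini, gcloud ADC
                        # makes a key optional.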
                        total_pairs = benchmark_stats.get("total_pairs", 1000)
                        max_pairs_slider = gr.Slider(
                            minimum=100,
                            maximum=total_pairs,
                            value=total_pairs,
                            step=100,
                            label="Max Pairs to Evaluate",
                            info="Use fewer pairs for faster evaluation",
                        )

                    with gr.Column(scale=3):
                        evaluate_btn = gr.Button(
                            "🚀 Run Evaluation",
                            variant="primary",
                            size="lg",
                        )

                        # Manual refresh button - visible while a job is running.
                        # This uses regular HTTP POST (not SSE) so it survives
                        # proxy timeouts.
                        with gr.Row():
                            check_status_btn = gr.Button(
                                "🔄 Check Status",
                                variant="secondary",
                                size="sm",
                                visible=False,
                            )
                            job_id_display = gr.Textbox(
                                label="",
                                visible=False,
                                interactive=False,
                                container=False,
                                scale=2,
                            )

                        status_text = gr.Markdown("")
                        results_markdown = gr.Markdown("")

        gr.Markdown("""
        ---
        ### About
        This benchmark evaluates embedding models for Rabbinic Hebrew and Aramaic
        texts using cross-lingual retrieval. All texts and translations sourced
        from [Sefaria](https://www.sefaria.org).
        """)

        # Event handlers
        model_dropdown.change(
            fn=update_model_inputs_visibility,
            inputs=[model_dropdown],
            outputs=[custom_model_input, api_key_input],
        )

        # Start evaluation: creates a persistent job file and spawns a background
        # thread. Returns immediately with the job_id so the UI doesn't time out
        # waiting.
        evaluate_btn.click(
            fn=start_evaluation,
            inputs=[model_dropdown, custom_model_input, api_key_input, max_pairs_slider],
            outputs=[
                job_id_state, status_text, results_markdown,
                leaderboard_table, check_status_btn, job_id_display,
            ],
        ).then(
            # Start JavaScript auto-polling after evaluation begins
            fn=None,
            inputs=None,
            outputs=None,
            js="""
            () => {
                // Clear any existing interval
                if (window.jobPollInterval) {
                    clearInterval(window.jobPollInterval);
                }
                console.log('[Auto-poll] Starting polling every 5 seconds');
                // Auto-click the Check Status button every 5 seconds
                window.jobPollInterval = setInterval(() => {
                    // Find the button by looking for "Check Status" text
                    const buttons = document.querySelectorAll('button');
                    let checkBtn = null;
                    for (const btn of buttons) {
                        if (btn.textContent.includes('Check Status')) {
                            checkBtn = btn;
                            break;
                        }
                    }
                    if (checkBtn && checkBtn.offsetParent !== null) {
                        console.log('[Auto-poll] Clicking Check Status button');
                        checkBtn.click();
                    } else {
                        // Button is hidden (job done), stop polling
                        console.log('[Auto-poll] Button not visible, stopping');
                        clearInterval(window.jobPollInterval);
                        window.jobPollInterval = null;
                    }
                }, 5000);
            }
            """,
        )

        # Check status button - uses regular HTTP POST (not SSE) so it
        # survives proxy timeouts
        check_status_btn.click(
            fn=check_job_status,
            inputs=[job_id_state],
            outputs=[
                status_text, results_markdown, leaderboard_table,
                check_status_btn, job_id_display,
            ],
        )

        def refresh_leaderboard():
            """Force refresh of the leaderboard from the Hub."""
            from leaderboard import clear_cache
            clear_cache()  # Clear the cache to force a fresh load
            return (format_leaderboard_df(), create_leaderboard_comparison())

        refresh_btn.click(
            fn=refresh_leaderboard,
            outputs=[leaderboard_table, comparison_plot],
        )

    return app


# Main entry point
if __name__ == "__main__":
    app = create_app()
    app.queue()
    app.launch()
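

# Headless usage (a sketch reusing the functions above; handy for smoke tests
# without launching the UI - slice the benchmark for a quick run):
#
#   benchmark = load_benchmark()[:200]
#   model = load_model("intfloat/multilingual-e5-large")
#   results = evaluate_model(model, benchmark, batch_size=32)
#   print(results.to_dict())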