"""
Gradio interface for Rabbinic Hebrew/Aramaic Embedding Evaluation.
A Hugging Face Space for evaluating embedding models on cross-lingual
retrieval between Hebrew/Aramaic source texts and English translations.
"""
import os
import threading
from datetime import datetime
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from data_loader import load_benchmark_dataset, get_benchmark_stats
from models import (
API_MODELS,
ALL_MODELS,
load_model,
validate_model_id,
requires_api_key,
api_key_optional,
get_api_key_type,
get_api_key_env_var,
)
from evaluation import EvaluationResults, evaluate_model
from leaderboard import (
load_leaderboard as load_leaderboard_from_hub,
add_result as add_result_to_hub,
)
from jobs import (
    create_job,
    get_job,
    update_job_progress,
    complete_job,
    fail_job,
    delete_job,
    cleanup_old_jobs,
    cleanup_stale_jobs,
)
# HuggingFace Dataset ID for benchmark data
BENCHMARK_DATASET_ID = "Sefaria/Rabbinic-Hebrew-English-Pairs"
# Global state
_benchmark_data = None
def load_benchmark():
"""Load benchmark data from HuggingFace Hub, with fallback to sample data."""
global _benchmark_data
if _benchmark_data is not None:
return _benchmark_data
try:
_benchmark_data = load_benchmark_dataset(BENCHMARK_DATASET_ID)
print(f"Loaded {len(_benchmark_data)} benchmark pairs from {BENCHMARK_DATASET_ID}")
except Exception as e:
print(f"Failed to load benchmark: {e}")
print("Using sample data for testing")
# Create minimal sample data for testing
_benchmark_data = [
{
"ref": "Sample.1",
"he": "בראשית ברא אלהים את השמים ואת הארץ",
"en": "In the beginning God created the heaven and the earth",
"category": "Sample",
},
{
"ref": "Sample.2",
"he": "והארץ היתה תהו ובהו וחשך על פני תהום",
"en": "And the earth was without form, and void; and darkness was upon the face of the deep",
"category": "Sample",
},
]
return _benchmark_data
def load_leaderboard():
"""Load leaderboard from HuggingFace Hub."""
return load_leaderboard_from_hub()
def add_to_leaderboard(results: EvaluationResults):
"""Add evaluation results to leaderboard on HuggingFace Hub."""
entry = results.to_dict()
entry["timestamp"] = datetime.now().isoformat()
# Add to Hub (handles deduplication and sorting internally)
success = add_result_to_hub(entry)
if not success:
print("Note: Results saved locally but not persisted to Hub (no HF_TOKEN)")
def format_leaderboard_df():
"""Format leaderboard as pandas DataFrame for display."""
leaderboard = load_leaderboard()
if not leaderboard:
return pd.DataFrame(columns=[
"#", "Model", "MRR", "R@1", "R@5", "R@10",
"Bitext", "TrueSim", "RandSim", "N"
])
rows = []
for i, entry in enumerate(leaderboard, 1):
rows.append({
"#": i,
"Model": entry.get("model_name", entry["model_id"]),
"MRR": f"{entry['mrr']:.3f}",
"R@1": f"{entry['recall_at_1']:.1%}",
"R@5": f"{entry['recall_at_5']:.1%}",
"R@10": f"{entry['recall_at_10']:.1%}",
"Bitext": f"{entry['bitext_accuracy']:.1%}",
"TrueSim": f"{entry['avg_true_pair_similarity']:.3f}",
"RandSim": f"{entry['avg_random_pair_similarity']:.3f}",
"N": entry["num_pairs"],
})
return pd.DataFrame(rows)
def run_evaluation_background(job_id: str, model_id: str, api_key: str, max_pairs: int):
"""
Run evaluation in background thread, writing progress to job file.
This function runs in a daemon thread. It writes progress updates to a
persistent job file so that the UI can poll for status even if the
original HTTP connection times out.
"""
try:
update_job_progress(job_id, "⏳ Loading benchmark data...", 0.0)
benchmark = load_benchmark()
        # Slider values can arrive as floats; coerce before slicing
        if max_pairs and int(max_pairs) < len(benchmark):
            benchmark = benchmark[:int(max_pairs)]
update_job_progress(job_id, f"⏳ Loading model: {model_id}...", 0.05)
model = load_model(model_id, api_key=api_key if api_key else None)
# Progress callback for evaluation - writes to job file
def progress_callback(progress_frac: float, msg: str):
# Scale progress: model loading is 0-10%, evaluation is 10-95%, saving is 95-100%
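            # e.g. progress_frac=0.5 (evaluation half done) maps to 0.10 + 0.5 * 0.85 = 0.525 overall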
scaled_progress = 0.10 + (progress_frac * 0.85)
update_job_progress(job_id, msg, scaled_progress)
update_job_progress(job_id, "⏳ Starting evaluation...", 0.10)
results = evaluate_model(
model,
benchmark,
batch_size=32,
progress_callback=progress_callback,
)
update_job_progress(job_id, "⏳ Saving results to leaderboard...", 0.95)
add_to_leaderboard(results)
# Format results summary
summary = f"""## Results for {results.model_name}
| Metric | Value |
|--------|-------|
| **MRR** | {results.mrr:.4f} |
| **Recall@1** | {results.recall_at_1:.1%} |
| **Recall@5** | {results.recall_at_5:.1%} |
| **Recall@10** | {results.recall_at_10:.1%} |
| **Bitext Accuracy** | {results.bitext_accuracy:.1%} |
| **Avg True Pair Sim** | {results.avg_true_pair_similarity:.4f} |
| **Avg Random Pair Sim** | {results.avg_random_pair_similarity:.4f} |
| **Pairs Evaluated** | {results.num_pairs:,} |
"""
complete_job(job_id, summary)
except Exception as e:
fail_job(job_id, str(e))
def start_evaluation(
model_choice: str,
custom_model_id: str,
api_key: str,
max_pairs: int,
):
"""
Start evaluation in background and return job ID.
Creates a persistent job file and starts a background thread.
The job file allows the UI to poll for status even if the
HTTP connection times out.
"""
# Determine which model to use
if model_choice == "custom":
model_id = custom_model_id.strip()
is_valid, error = validate_model_id(model_id)
if not is_valid:
return (
"", # job_id
f"❌ {error}",
f"❌ Invalid model ID: {error}",
format_leaderboard_df(),
gr.update(visible=False), # check_status_btn
gr.update(visible=False, value=""), # job_id_display
)
else:
model_id = model_choice
# Check if API key is required but not provided
if requires_api_key(model_id):
api_key = api_key.strip() if api_key else ""
env_var = get_api_key_env_var(model_id)
key_type = get_api_key_type(model_id)
if not api_key and not os.environ.get(env_var) and not api_key_optional(model_id):
return (
"",
"❌ API key required",
f"❌ API key required for {model_id}. Please enter your {key_type.upper()} API key or set the {env_var} environment variable.",
format_leaderboard_df(),
gr.update(visible=False),
gr.update(visible=False, value=""),
)
# Get model display name
model_name = model_id
if model_id in ALL_MODELS:
model_name = ALL_MODELS[model_id].get("name", model_id)
# Create persistent job
job = create_job(model_id=model_id, model_name=model_name, max_pairs=max_pairs)
# Start background thread
thread = threading.Thread(
target=run_evaluation_background,
args=(job.job_id, model_id, api_key, max_pairs),
daemon=True,
)
thread.start()
return (
job.job_id,
"⏳ Evaluation started! Click 'Check Status' to see progress (auto-refreshes every 5 seconds).",
"",
gr.update(), # Don't update leaderboard yet
gr.update(visible=True), # Show check_status_btn
gr.update(visible=True, value=f"Job ID: {job.job_id[:8]}..."), # Show job_id_display
)
def check_job_status(job_id: str):
"""
Check job status by reading the job file.
This is a stateless operation - each check reads fresh data from disk.
Uses regular HTTP POST (not SSE) so it survives HF Spaces proxy timeouts.
"""
if not job_id:
return (
"",
"",
gr.update(),
gr.update(visible=False), # Hide check button
gr.update(visible=False, value=""), # Hide job ID
)
job = get_job(job_id)
if job is None:
# Job not found - might have been cleaned up or never existed
return (
"⚠️ Job not found. It may have expired or been cleaned up.",
"",
gr.update(),
gr.update(visible=False),
gr.update(visible=False, value=""),
)
if job.status == "completed":
# Job completed successfully
# Clean up the job file after retrieving results
delete_job(job_id)
return (
job.progress,
job.result or "",
format_leaderboard_df(),
gr.update(visible=False), # Hide check button
gr.update(visible=False, value=""), # Hide job ID
)
elif job.status == "failed":
# Job failed
error_msg = job.error or "Unknown error"
# Clean up the job file
delete_job(job_id)
return (
job.progress,
f"❌ Error: {error_msg}",
format_leaderboard_df(),
gr.update(visible=False), # Hide check button
gr.update(visible=False, value=""), # Hide job ID
)
else:
# Still running (status is "pending" or "running")
# Include progress percentage in the status message
pct = int(job.progress_pct * 100)
progress_with_pct = f"{job.progress} ({pct}%)"
return (
progress_with_pct,
"",
gr.update(), # Don't update leaderboard yet
gr.update(visible=True), # Keep check button visible
gr.update(visible=True), # Keep job ID visible
)
def create_leaderboard_comparison():
"""Create comparison chart of all models on leaderboard."""
leaderboard = load_leaderboard()
if len(leaderboard) < 2:
return None
models = [e.get("model_name", e["model_id"]) for e in leaderboard]
mrr = [e["mrr"] for e in leaderboard]
r1 = [e["recall_at_1"] for e in leaderboard]
r5 = [e["recall_at_5"] for e in leaderboard]
r10 = [e["recall_at_10"] for e in leaderboard]
bitext = [e["bitext_accuracy"] for e in leaderboard]
fig = go.Figure()
fig.add_trace(go.Bar(name="MRR", x=models, y=mrr, marker_color="#2E86AB"))
fig.add_trace(go.Bar(name="R@1", x=models, y=r1, marker_color="#A23B72"))
fig.add_trace(go.Bar(name="R@5", x=models, y=r5, marker_color="#F18F01"))
fig.add_trace(go.Bar(name="R@10", x=models, y=r10, marker_color="#C73E1D"))
fig.add_trace(go.Bar(name="Bitext Acc", x=models, y=bitext, marker_color="#6B5B95"))
fig.update_layout(
title="Model Comparison",
yaxis_title="Score",
yaxis_range=[0, 1],
barmode="group",
template="plotly_white",
height=400,
)
return fig
def update_model_inputs_visibility(choice):
"""Show/hide custom model input and API key based on selection."""
show_custom = (choice == "custom")
show_api_key = requires_api_key(choice) if choice != "custom" else False
# Update API key label based on model type
if show_api_key:
key_type = get_api_key_type(choice)
env_var = get_api_key_env_var(choice)
if key_type == "voyage":
label = "Voyage AI API Key"
placeholder = f"Enter your Voyage AI API key (or set {env_var} env var)"
elif key_type == "gemini":
label = "Gemini API Key (optional if using gcloud)"
placeholder = f"Leave blank if using gcloud ADC, or enter API key / set {env_var}"
elif key_type == "cohere":
label = "Cohere API Key"
placeholder = f"Enter your Cohere API key (or set {env_var} env var)"
else:
label = "OpenAI API Key"
placeholder = f"Enter your OpenAI API key (or set {env_var} env var)"
return (
gr.update(visible=show_custom),
gr.update(visible=show_api_key, label=label, placeholder=placeholder),
)
return (
gr.update(visible=show_custom),
gr.update(visible=show_api_key),
)
# Build the Gradio interface
def create_app():
"""Create and return the Gradio app."""
# Clean up any stale jobs from previous runs (e.g., if Space restarted mid-evaluation)
print("Cleaning up stale jobs...")
cleanup_stale_jobs(stale_minutes=30)
cleanup_old_jobs(max_age_hours=24)
# Get all model choices - custom first, then API models
model_choices = []
# Custom option first
model_choices.append(("⚙️ Custom Model (enter HuggingFace ID below)", "custom"))
# API models
for model_id, info in API_MODELS.items():
model_choices.append((f"🌐 {info['name']}", model_id))
# Load initial data
load_benchmark()
load_leaderboard()
benchmark_stats = get_benchmark_stats(_benchmark_data) if _benchmark_data else {}
with gr.Blocks(
title="Rabbinic Embedding Benchmark",
theme=gr.themes.Soft(
primary_hue="blue",
secondary_hue="orange",
font=gr.themes.GoogleFont("Source Sans Pro"),
),
css="""
.main-header {
text-align: center;
margin-bottom: 1rem;
}
.stats-box {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 1rem;
border-radius: 8px;
margin: 0.5rem 0;
}
""",
) as app:
# Hidden state for job tracking (persists job ID across poll requests)
job_id_state = gr.State("")
gr.Markdown(
"""
# 📚 Rabbinic Hebrew/Aramaic Embedding Benchmark
Evaluate embedding models on cross-lingual retrieval between Hebrew/Aramaic
source texts and their English translations from Sefaria.
**How it works:** Given a Hebrew/Aramaic text, can the model find its correct
English translation from a pool of candidates? Models that excel at this task
produce high-quality embeddings for Rabbinic literature.
""",
elem_classes=["main-header"],
)
with gr.Row():
with gr.Column(scale=1):
gr.Markdown(f"""
### 📊 Benchmark Stats
                - **Total Pairs:** {benchmark_stats.get('total_pairs', 0):,}
- **Categories:** {len(benchmark_stats.get('categories', {}))}
- **Avg Hebrew Length:** {benchmark_stats.get('avg_he_length', 0):.0f} chars
- **Dataset:** [View on Hugging Face](https://huggingface.co/datasets/{BENCHMARK_DATASET_ID})
""")
with gr.Column(scale=1):
gr.Markdown("""
### 📏 Metrics
- **MRR:** Mean Reciprocal Rank
- **R@k:** Recall at k (correct in top k)
- **Bitext Acc:** True vs random pair classification
""")
gr.Markdown("---")
with gr.Tabs(selected=0): # Default to Leaderboard tab
with gr.TabItem("🏆 Leaderboard"):
leaderboard_table = gr.Dataframe(
value=format_leaderboard_df(),
label="Model Rankings",
interactive=False,
)
refresh_btn = gr.Button("🔄 Refresh Leaderboard")
comparison_plot = gr.Plot(
value=create_leaderboard_comparison(),
label="Model Comparison"
)
with gr.TabItem("🔬 Evaluate Model"):
with gr.Row():
with gr.Column(scale=2):
model_dropdown = gr.Dropdown(
choices=model_choices,
value="custom",
label="Select Model",
info="Enter a HuggingFace model ID or choose an API model",
)
custom_model_input = gr.Textbox(
label="HuggingFace Model ID",
placeholder="e.g., intfloat/multilingual-e5-large",
visible=True, # Visible by default since "custom" is selected
)
api_key_input = gr.Textbox(
label="API Key",
placeholder="Enter your API key (or set appropriate env var)",
type="password",
visible=False,
info="Required for API-based models (OpenAI, Voyage AI). Your key is not stored.",
)
                        total_pairs = benchmark_stats.get('total_pairs', 1000)
                        # Guard against tiny fallback datasets where total_pairs
                        # would fall below the slider minimum
                        slider_max = max(total_pairs, 100)
                        max_pairs_slider = gr.Slider(
                            minimum=100,
                            maximum=slider_max,
                            value=slider_max,
                            step=100,
                            label="Max Pairs to Evaluate",
                            info="Use fewer pairs for faster evaluation",
                        )
with gr.Column(scale=3):
evaluate_btn = gr.Button(
"🚀 Run Evaluation",
variant="primary",
size="lg",
)
# Manual refresh button - visible when a job is running
# This uses regular HTTP POST (not SSE) so it survives proxy timeouts
with gr.Row():
check_status_btn = gr.Button(
"🔄 Check Status",
variant="secondary",
size="sm",
visible=False,
)
job_id_display = gr.Textbox(
label="",
visible=False,
interactive=False,
container=False,
scale=2,
)
status_text = gr.Markdown("")
results_markdown = gr.Markdown("")
gr.Markdown("""
---
### About
This benchmark evaluates embedding models for Rabbinic Hebrew and Aramaic texts using
cross-lingual retrieval.
All texts and translations sourced from [Sefaria](https://www.sefaria.org).
""")
# Event handlers
model_dropdown.change(
fn=update_model_inputs_visibility,
inputs=[model_dropdown],
outputs=[custom_model_input, api_key_input],
)
# Start evaluation: creates persistent job file and spawns background thread.
# Returns immediately with job_id so UI doesn't timeout waiting.
evaluate_btn.click(
fn=start_evaluation,
inputs=[model_dropdown, custom_model_input, api_key_input, max_pairs_slider],
outputs=[job_id_state, status_text, results_markdown, leaderboard_table, check_status_btn, job_id_display],
).then(
# Start JavaScript auto-polling after evaluation begins
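            # fn=None with a js snippet runs purely client-side: no server
            # round-trip, just the polling loop installed in the browser.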
fn=None,
inputs=None,
outputs=None,
js="""
() => {
// Clear any existing interval
if (window.jobPollInterval) {
clearInterval(window.jobPollInterval);
}
console.log('[Auto-poll] Starting polling every 5 seconds');
// Auto-click the check status button every 5 seconds
window.jobPollInterval = setInterval(() => {
// Find button by looking for "Check Status" text
const buttons = document.querySelectorAll('button');
let checkBtn = null;
for (const btn of buttons) {
if (btn.textContent.includes('Check Status')) {
checkBtn = btn;
break;
}
}
if (checkBtn && checkBtn.offsetParent !== null) {
console.log('[Auto-poll] Clicking Check Status button');
checkBtn.click();
} else {
// Button is hidden (job done), stop polling
console.log('[Auto-poll] Button not visible, stopping');
clearInterval(window.jobPollInterval);
window.jobPollInterval = null;
}
}, 5000);
}
""",
)
# Check status button - uses regular HTTP POST (not SSE) so it survives proxy timeouts
check_status_btn.click(
fn=check_job_status,
inputs=[job_id_state],
outputs=[status_text, results_markdown, leaderboard_table, check_status_btn, job_id_display],
)
def refresh_leaderboard():
"""Force refresh leaderboard from Hub."""
from leaderboard import clear_cache
clear_cache() # Clear cache to force fresh load
return (format_leaderboard_df(), create_leaderboard_comparison())
refresh_btn.click(
fn=refresh_leaderboard,
outputs=[leaderboard_table, comparison_plot],
)
return app
# Main entry point
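# Run locally with `python app.py`; Gradio serves on http://127.0.0.1:7860 by
# default. On a Hugging Face Space, the same entry point is executed by the
# Space runtime.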
if __name__ == "__main__":
app = create_app()
app.queue()
app.launch()