# Rabbinic-Embedding-Bench / evaluation.py
"""
Cross-lingual retrieval evaluation for Rabbinic embedding benchmark.
Computes retrieval metrics to measure how well embedding models align
Hebrew/Aramaic source texts with their English translations.
"""
from dataclasses import dataclass
from typing import Optional
import numpy as np
@dataclass
class EvaluationResults:
"""Container for evaluation results."""
model_id: str
model_name: str
# Core retrieval metrics
recall_at_1: float
recall_at_5: float
recall_at_10: float
mrr: float # Mean Reciprocal Rank
# Additional metrics
bitext_accuracy: float # True pair vs random pair classification
avg_true_pair_similarity: float
avg_random_pair_similarity: float
# Metadata
num_pairs: int
categories: dict[str, int]
def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
return {
"model_id": self.model_id,
"model_name": self.model_name,
"recall_at_1": self.recall_at_1,
"recall_at_5": self.recall_at_5,
"recall_at_10": self.recall_at_10,
"mrr": self.mrr,
"bitext_accuracy": self.bitext_accuracy,
"avg_true_pair_similarity": self.avg_true_pair_similarity,
"avg_random_pair_similarity": self.avg_random_pair_similarity,
"num_pairs": self.num_pairs,
"categories": self.categories,
}
@classmethod
def from_dict(cls, data: dict) -> "EvaluationResults":
"""Create from dictionary."""
return cls(**data)
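# Example round trip (a minimal sketch of the intended (de)serialization; the metric
# values and category names below are illustrative placeholders, not real results):
#
#   results = EvaluationResults(
#       model_id="example/model", model_name="Example Model",
#       recall_at_1=0.80, recall_at_5=0.95, recall_at_10=0.97, mrr=0.86,
#       bitext_accuracy=0.99, avg_true_pair_similarity=0.71,
#       avg_random_pair_similarity=0.12, num_pairs=500,
#       categories={"Talmud": 300, "Mishnah": 200},
#   )
#   assert EvaluationResults.from_dict(results.to_dict()) == results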
def compute_similarity_matrix(
query_embeddings: np.ndarray,
passage_embeddings: np.ndarray,
) -> np.ndarray:
"""
Compute cosine similarity matrix between queries and passages.
Assumes embeddings are already L2-normalized.
Args:
query_embeddings: (N, D) array of query embeddings
passage_embeddings: (M, D) array of passage embeddings
Returns:
(N, M) similarity matrix
"""
return np.dot(query_embeddings, passage_embeddings.T)
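# Example (a minimal sketch; as the function assumes, the embeddings are L2-normalized
# before the dot product, so the result is a cosine-similarity matrix):
#
#   q = np.random.randn(4, 768); q /= np.linalg.norm(q, axis=1, keepdims=True)
#   p = np.random.randn(6, 768); p /= np.linalg.norm(p, axis=1, keepdims=True)
#   sims = compute_similarity_matrix(q, p)   # shape (4, 6), values in [-1, 1]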
def compute_retrieval_metrics(
similarity_matrix: np.ndarray,
k_values: list[int] = [1, 5, 10],
) -> dict[str, float]:
"""
Compute retrieval metrics from similarity matrix.
Assumes the correct match for query i is passage i (diagonal).
Args:
similarity_matrix: (N, N) similarity matrix where diagonal is true matches
k_values: List of k values for Recall@k
Returns:
Dict with recall@k and mrr values
"""
n = similarity_matrix.shape[0]
# Get rankings for each query
# Negate to sort descending (highest similarity first)
rankings = np.argsort(-similarity_matrix, axis=1)
# Find rank of true match (diagonal) for each query
true_ranks = np.zeros(n, dtype=int)
for i in range(n):
# Find position of index i in the ranking for query i
true_ranks[i] = np.where(rankings[i] == i)[0][0]
results = {}
# Recall@k: fraction where true match is in top k
for k in k_values:
recall = np.mean(true_ranks < k)
results[f"recall_at_{k}"] = float(recall)
# MRR: Mean Reciprocal Rank
reciprocal_ranks = 1.0 / (true_ranks + 1) # +1 because ranks are 0-indexed
results["mrr"] = float(np.mean(reciprocal_ranks))
return results
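# Worked example (a minimal sketch): with an identity similarity matrix every query
# ranks its own passage first, so Recall@k and MRR are all 1.0:
#
#   compute_retrieval_metrics(np.eye(3), k_values=[1])
#   # -> {"recall_at_1": 1.0, "mrr": 1.0}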
def compute_bitext_accuracy(
similarity_matrix: np.ndarray,
num_negatives: int = 10,
) -> tuple[float, float, float]:
"""
Compute bitext mining accuracy.
For each true pair, sample random negative pairs and check if the model
correctly ranks the true pair higher.
Args:
similarity_matrix: (N, N) similarity matrix
num_negatives: Number of negative samples per true pair
Returns:
Tuple of (accuracy, avg_true_sim, avg_random_sim)
"""
n = similarity_matrix.shape[0]
# True pair similarities (diagonal)
true_similarities = np.diag(similarity_matrix)
# Sample random negative pairs
correct = 0
total = 0
random_sims = []
rng = np.random.default_rng(42)
for i in range(n):
true_sim = true_similarities[i]
# Sample random passage indices (not the true match)
neg_indices = rng.choice(
[j for j in range(n) if j != i],
size=min(num_negatives, n - 1),
replace=False,
)
for j in neg_indices:
neg_sim = similarity_matrix[i, j]
random_sims.append(neg_sim)
if true_sim > neg_sim:
correct += 1
total += 1
accuracy = correct / total if total > 0 else 0.0
avg_true = float(np.mean(true_similarities))
avg_random = float(np.mean(random_sims)) if random_sims else 0.0
return accuracy, avg_true, avg_random
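# Example (a minimal sketch): on an identity matrix every true pair scores 1.0 and
# every sampled negative 0.0, so the true pair always wins:
#
#   acc, true_sim, rand_sim = compute_bitext_accuracy(np.eye(20), num_negatives=5)
#   # acc == 1.0, true_sim == 1.0, rand_sim == 0.0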
def evaluate_model(
model,
benchmark_pairs: list[dict],
batch_size: int = 32,
max_pairs: Optional[int] = None,
progress_callback=None,
) -> EvaluationResults:
"""
Run full evaluation of a model on the benchmark.
Args:
model: EmbeddingModel instance
benchmark_pairs: List of benchmark pairs with 'he', 'en', 'category' keys
batch_size: Batch size for encoding
max_pairs: Maximum pairs to evaluate (for faster testing)
progress_callback: Optional callback(progress_fraction, message) for progress updates
Returns:
EvaluationResults with all metrics
"""
from collections import Counter
# Optionally limit pairs
if max_pairs and len(benchmark_pairs) > max_pairs:
benchmark_pairs = benchmark_pairs[:max_pairs]
# Extract texts
he_texts = [p["he"] for p in benchmark_pairs]
en_texts = [p["en"] for p in benchmark_pairs]
categories = Counter(p.get("category", "Unknown") for p in benchmark_pairs)
n_total = len(he_texts)
n_batches = (n_total + batch_size - 1) // batch_size
def report_progress(phase, batch_idx, total_batches):
"""Report progress to callback if available."""
if progress_callback:
# Phase 1: Hebrew encoding (0-45%), Phase 2: English encoding (45-90%), Phase 3: Metrics (90-100%)
if phase == "hebrew":
progress = 0.45 * (batch_idx / total_batches)
elif phase == "english":
progress = 0.45 + 0.45 * (batch_idx / total_batches)
else:
progress = 0.9 + 0.1 * batch_idx # For final steps
progress_callback(progress, f"⏳ {phase.capitalize()}: {batch_idx}/{total_batches}")
# Encode Hebrew texts in batches
if progress_callback:
progress_callback(0, f"⏳ Encoding Hebrew/Aramaic texts: 0/{n_total:,}")
he_embeddings_list = []
for i in range(0, len(he_texts), batch_size):
batch = he_texts[i:i + batch_size]
batch_emb = model.encode(
batch,
is_query=True,
batch_size=batch_size,
show_progress=False,
)
he_embeddings_list.append(batch_emb)
done = min(i + batch_size, len(he_texts))
batch_idx = (i // batch_size) + 1
if progress_callback:
progress_callback(0.45 * batch_idx / n_batches, f"⏳ Encoding Hebrew/Aramaic: {done:,}/{n_total:,}")
he_embeddings = np.vstack(he_embeddings_list)
# Encode English texts in batches
if progress_callback:
progress_callback(0.45, f"⏳ Encoding English texts: 0/{n_total:,}")
en_embeddings_list = []
for i in range(0, len(en_texts), batch_size):
batch = en_texts[i:i + batch_size]
batch_emb = model.encode(
batch,
is_query=False,
batch_size=batch_size,
show_progress=False,
)
en_embeddings_list.append(batch_emb)
done = min(i + batch_size, len(en_texts))
batch_idx = (i // batch_size) + 1
if progress_callback:
progress_callback(0.45 + 0.45 * batch_idx / n_batches, f"⏳ Encoding English: {done:,}/{n_total:,}")
en_embeddings = np.vstack(en_embeddings_list)
if progress_callback:
progress_callback(0.92, "⏳ Computing similarity matrix...")
similarity_matrix = compute_similarity_matrix(he_embeddings, en_embeddings)
if progress_callback:
progress_callback(0.95, "⏳ Computing retrieval metrics...")
retrieval_metrics = compute_retrieval_metrics(similarity_matrix)
if progress_callback:
progress_callback(0.98, "⏳ Computing bitext accuracy...")
bitext_acc, avg_true_sim, avg_random_sim = compute_bitext_accuracy(similarity_matrix)
if progress_callback:
progress_callback(1.0, "✅ Evaluation complete!")
return EvaluationResults(
model_id=model.model_id,
model_name=model.name,
recall_at_1=retrieval_metrics["recall_at_1"],
recall_at_5=retrieval_metrics["recall_at_5"],
recall_at_10=retrieval_metrics["recall_at_10"],
mrr=retrieval_metrics["mrr"],
bitext_accuracy=bitext_acc,
avg_true_pair_similarity=avg_true_sim,
avg_random_pair_similarity=avg_random_sim,
num_pairs=len(benchmark_pairs),
categories=dict(categories),
)
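# Example usage (a minimal sketch; `load_model` is a hypothetical loader standing in
# for however EmbeddingModel instances are constructed, and `pairs` is a list of dicts
# with "he", "en", and "category" keys as described above):
#
#   model = load_model("example/embedding-model")
#   results = evaluate_model(
#       model,
#       pairs,
#       batch_size=32,
#       progress_callback=lambda frac, msg: print(f"{frac:.0%} {msg}"),
#   )
#   print(results.to_dict())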
def evaluate_model_streaming(
model,
benchmark_pairs: list[dict],
batch_size: int = 32,
max_pairs: Optional[int] = None,
):
"""
Run evaluation with streaming progress updates.
Yields progress strings during encoding, then yields final EvaluationResults.
Args:
model: EmbeddingModel instance
benchmark_pairs: List of benchmark pairs with 'he', 'en', 'category' keys
batch_size: Batch size for encoding
max_pairs: Maximum pairs to evaluate (for faster testing)
Yields:
Progress strings, then final EvaluationResults
"""
from collections import Counter
# Optionally limit pairs
if max_pairs and len(benchmark_pairs) > max_pairs:
benchmark_pairs = benchmark_pairs[:max_pairs]
# Extract texts
he_texts = [p["he"] for p in benchmark_pairs]
en_texts = [p["en"] for p in benchmark_pairs]
categories = Counter(p.get("category", "Unknown") for p in benchmark_pairs)
n_total = len(he_texts)
# Encode Hebrew texts in batches with progress
yield f"⏳ Encoding Hebrew/Aramaic texts: 0/{n_total:,}"
he_embeddings_list = []
for i in range(0, len(he_texts), batch_size):
batch = he_texts[i:i + batch_size]
batch_emb = model.encode(
batch,
is_query=True,
batch_size=batch_size,
show_progress=False,
)
he_embeddings_list.append(batch_emb)
done = min(i + batch_size, len(he_texts))
yield f"⏳ Encoding Hebrew/Aramaic texts: {done:,}/{n_total:,}"
he_embeddings = np.vstack(he_embeddings_list)
# Encode English texts in batches with progress
yield f"⏳ Encoding English texts: 0/{n_total:,}"
en_embeddings_list = []
for i in range(0, len(en_texts), batch_size):
batch = en_texts[i:i + batch_size]
batch_emb = model.encode(
batch,
is_query=False,
batch_size=batch_size,
show_progress=False,
)
en_embeddings_list.append(batch_emb)
done = min(i + batch_size, len(en_texts))
yield f"⏳ Encoding English texts: {done:,}/{n_total:,}"
en_embeddings = np.vstack(en_embeddings_list)
yield "⏳ Computing similarity matrix..."
similarity_matrix = compute_similarity_matrix(he_embeddings, en_embeddings)
yield "⏳ Computing retrieval metrics..."
retrieval_metrics = compute_retrieval_metrics(similarity_matrix)
yield "⏳ Computing bitext accuracy..."
bitext_acc, avg_true_sim, avg_random_sim = compute_bitext_accuracy(
similarity_matrix
)
# Yield final results
yield EvaluationResults(
model_id=model.model_id,
model_name=model.name,
recall_at_1=retrieval_metrics["recall_at_1"],
recall_at_5=retrieval_metrics["recall_at_5"],
recall_at_10=retrieval_metrics["recall_at_10"],
mrr=retrieval_metrics["mrr"],
bitext_accuracy=bitext_acc,
avg_true_pair_similarity=avg_true_sim,
avg_random_pair_similarity=avg_random_sim,
num_pairs=len(benchmark_pairs),
categories=dict(categories),
)
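# Example of consuming the generator (a minimal sketch; progress strings arrive first
# and the final yielded value is the EvaluationResults object):
#
#   final = None
#   for update in evaluate_model_streaming(model, pairs, batch_size=32):
#       if isinstance(update, EvaluationResults):
#           final = update
#       else:
#           print(update)   # e.g. "⏳ Encoding English texts: 320/1,000"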
def evaluate_by_category(
model,
benchmark_pairs: list[dict],
batch_size: int = 32,
) -> dict[str, EvaluationResults]:
"""
Run evaluation broken down by category.
Args:
model: EmbeddingModel instance
benchmark_pairs: List of benchmark pairs
batch_size: Batch size for encoding
Returns:
Dict mapping category name to EvaluationResults
"""
from collections import defaultdict
# Group pairs by category
by_category = defaultdict(list)
for pair in benchmark_pairs:
category = pair.get("category", "Unknown")
by_category[category].append(pair)
results = {}
for category, pairs in by_category.items():
print(f"Evaluating category: {category} ({len(pairs)} pairs)")
results[category] = evaluate_model(model, pairs, batch_size=batch_size)
return results
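# Example (a minimal sketch, reusing the hypothetical `model` and `pairs` from the
# evaluate_model example above):
#
#   per_category = evaluate_by_category(model, pairs)
#   for category, res in sorted(per_category.items()):
#       print(f"{category}: R@1={res.recall_at_1:.3f}  MRR={res.mrr:.3f}  n={res.num_pairs}")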
def get_rank_distribution(
similarity_matrix: np.ndarray,
bins: list[int] = [1, 5, 10, 50, 100],
) -> dict[str, int]:
"""
Get distribution of true match ranks.
Args:
similarity_matrix: (N, N) similarity matrix
bins: Bin boundaries for histogram
Returns:
Dict mapping bin labels to counts
"""
n = similarity_matrix.shape[0]
rankings = np.argsort(-similarity_matrix, axis=1)
# Find true rank for each query
true_ranks = np.zeros(n, dtype=int)
for i in range(n):
true_ranks[i] = np.where(rankings[i] == i)[0][0]
# Create histogram
distribution = {}
prev_bin = 0
for bin_edge in bins:
count = np.sum((true_ranks >= prev_bin) & (true_ranks < bin_edge))
label = f"{prev_bin+1}-{bin_edge}" if prev_bin > 0 else f"Top {bin_edge}"
distribution[label] = int(count)
prev_bin = bin_edge
# Count remaining
remaining = np.sum(true_ranks >= bins[-1])
distribution[f">{bins[-1]}"] = int(remaining)
return distribution
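# Example (a minimal sketch): with the default bins the returned keys are "Top 1",
# "2-5", "6-10", "11-50", "51-100", and ">100". For an identity matrix every true
# match is ranked first:
#
#   get_rank_distribution(np.eye(200))
#   # -> {"Top 1": 200, "2-5": 0, "6-10": 0, "11-50": 0, "51-100": 0, ">100": 0}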
if __name__ == "__main__":
# Test with sample data
print("Testing evaluation functions...")
    # Create a sample similarity matrix with a strong diagonal (near-perfect retrieval)
    n = 100
    perfect_matrix = np.eye(n) + np.random.randn(n, n) * 0.1
    metrics = compute_retrieval_metrics(perfect_matrix)
    print(f"Near-perfect retrieval metrics: {metrics}")
    # Test with unrelated random embeddings (retrieval should be near chance level)
    queries = np.random.randn(n, 64)
    queries = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    passages = np.random.randn(n, 64)
    passages = passages / np.linalg.norm(passages, axis=1, keepdims=True)
    random_matrix = compute_similarity_matrix(queries, passages)
metrics = compute_retrieval_metrics(random_matrix)
print(f"Random retrieval metrics: {metrics}")