|
|
""" |
|
|
Cross-lingual retrieval evaluation for the Rabbinic embedding benchmark.
|
|
|
|
|
Computes retrieval metrics to measure how well embedding models align |
|
|
Hebrew/Aramaic source texts with their English translations. |
|
|
""" |
|
|
|
|
|
from dataclasses import dataclass |
|
|
from typing import Optional |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class EvaluationResults: |
|
|
"""Container for evaluation results.""" |
|
|
|
|
|
model_id: str |
|
|
model_name: str |
|
|
|
|
|
|
|
|
recall_at_1: float |
|
|
recall_at_5: float |
|
|
recall_at_10: float |
|
|
mrr: float |
|
|
|
|
|
|
|
|
bitext_accuracy: float |
|
|
avg_true_pair_similarity: float |
|
|
avg_random_pair_similarity: float |
|
|
|
|
|
|
|
|
num_pairs: int |
|
|
categories: dict[str, int] |
|
|
|
|
|
def to_dict(self) -> dict: |
|
|
"""Convert to dictionary for JSON serialization.""" |
|
|
return { |
|
|
"model_id": self.model_id, |
|
|
"model_name": self.model_name, |
|
|
"recall_at_1": self.recall_at_1, |
|
|
"recall_at_5": self.recall_at_5, |
|
|
"recall_at_10": self.recall_at_10, |
|
|
"mrr": self.mrr, |
|
|
"bitext_accuracy": self.bitext_accuracy, |
|
|
"avg_true_pair_similarity": self.avg_true_pair_similarity, |
|
|
"avg_random_pair_similarity": self.avg_random_pair_similarity, |
|
|
"num_pairs": self.num_pairs, |
|
|
"categories": self.categories, |
|
|
} |
|
|
|
|
|
@classmethod |
|
|
def from_dict(cls, data: dict) -> "EvaluationResults": |
|
|
"""Create from dictionary.""" |
|
|
return cls(**data) |
|
|
|
|
|
|
|
|
def compute_similarity_matrix( |
|
|
query_embeddings: np.ndarray, |
|
|
passage_embeddings: np.ndarray, |
|
|
) -> np.ndarray: |
|
|
""" |
|
|
Compute cosine similarity matrix between queries and passages. |
|
|
|
|
|
Assumes embeddings are already L2-normalized. |
|
|
|
|
|
Args: |
|
|
query_embeddings: (N, D) array of query embeddings |
|
|
passage_embeddings: (M, D) array of passage embeddings |
|
|
|
|
|
Returns: |
|
|
(N, M) similarity matrix |
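
    Example (toy doctest; rows of np.eye are already unit-normalized):
        >>> q = np.eye(2, 3)
        >>> compute_similarity_matrix(q, q).shape
        (2, 2)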
|
|
""" |
|
|
return np.dot(query_embeddings, passage_embeddings.T) |
|
|
|
|
|
|
|
|
def compute_retrieval_metrics( |
|
|
similarity_matrix: np.ndarray, |
|
|
    k_values: tuple[int, ...] = (1, 5, 10),
|
|
) -> dict[str, float]: |
|
|
""" |
|
|
Compute retrieval metrics from similarity matrix. |
|
|
|
|
|
Assumes the correct match for query i is passage i (diagonal). |
|
|
|
|
|
Args: |
|
|
similarity_matrix: (N, N) similarity matrix where diagonal is true matches |
|
|
        k_values: k values for Recall@k
|
|
|
|
|
Returns: |
|
|
Dict with recall@k and mrr values |
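
    Example (toy doctest; identity matrix gives perfect retrieval):
        >>> compute_retrieval_metrics(np.eye(3), k_values=(1,))
        {'recall_at_1': 1.0, 'mrr': 1.0}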
|
|
""" |
|
|
n = similarity_matrix.shape[0] |
|
|
|
|
|
|
|
|
|
|
|
    # Rank all passages for each query by descending similarity.
    rankings = np.argsort(-similarity_matrix, axis=1)
|
|
|
|
|
|
|
|
true_ranks = np.zeros(n, dtype=int) |
|
|
for i in range(n): |
|
|
|
|
|
        # 0-based rank of the true passage i in query i's ranking.
        true_ranks[i] = np.where(rankings[i] == i)[0][0]
|
|
|
|
|
results = {} |
|
|
|
|
|
|
|
|
for k in k_values: |
|
|
recall = np.mean(true_ranks < k) |
|
|
results[f"recall_at_{k}"] = float(recall) |
|
|
|
|
|
|
|
|
    # MRR: mean over queries of 1 / (1-based rank of the true match).
    reciprocal_ranks = 1.0 / (true_ranks + 1)
|
|
results["mrr"] = float(np.mean(reciprocal_ranks)) |
|
|
|
|
|
return results |
|
|
|
|
|
|
|
|
def compute_bitext_accuracy( |
|
|
similarity_matrix: np.ndarray, |
|
|
num_negatives: int = 10, |
|
|
) -> tuple[float, float, float]: |
|
|
""" |
|
|
Compute bitext mining accuracy. |
|
|
|
|
|
For each true pair, sample random negative pairs and check if the model |
|
|
correctly ranks the true pair higher. |
|
|
|
|
|
Args: |
|
|
similarity_matrix: (N, N) similarity matrix |
|
|
num_negatives: Number of negative samples per true pair |
|
|
|
|
|
Returns: |
|
|
Tuple of (accuracy, avg_true_sim, avg_random_sim) |
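
    Example (toy doctest; with an identity matrix the true pair always wins):
        >>> compute_bitext_accuracy(np.eye(4), num_negatives=2)
        (1.0, 1.0, 0.0)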
|
|
""" |
|
|
n = similarity_matrix.shape[0] |
|
|
|
|
|
|
|
|
true_similarities = np.diag(similarity_matrix) |
|
|
|
|
|
|
|
|
correct = 0 |
|
|
total = 0 |
|
|
random_sims = [] |
|
|
|
|
|
    rng = np.random.default_rng(42)  # Fixed seed for reproducible negative sampling.
|
|
|
|
|
for i in range(n): |
|
|
true_sim = true_similarities[i] |
|
|
|
|
|
|
|
|
        # Sample distinct negative passages, excluding the true match.
        neg_indices = rng.choice(
|
|
[j for j in range(n) if j != i], |
|
|
size=min(num_negatives, n - 1), |
|
|
replace=False, |
|
|
) |
|
|
|
|
|
for j in neg_indices: |
|
|
neg_sim = similarity_matrix[i, j] |
|
|
random_sims.append(neg_sim) |
|
|
|
|
|
if true_sim > neg_sim: |
|
|
correct += 1 |
|
|
total += 1 |
|
|
|
|
|
accuracy = correct / total if total > 0 else 0.0 |
|
|
avg_true = float(np.mean(true_similarities)) |
|
|
avg_random = float(np.mean(random_sims)) if random_sims else 0.0 |
|
|
|
|
|
return accuracy, avg_true, avg_random |
|
|
|
|
|
|
|
|
def evaluate_model( |
|
|
model, |
|
|
benchmark_pairs: list[dict], |
|
|
batch_size: int = 32, |
|
|
max_pairs: Optional[int] = None, |
|
|
progress_callback=None, |
|
|
) -> EvaluationResults: |
|
|
""" |
|
|
Run full evaluation of a model on the benchmark. |
|
|
|
|
|
Args: |
|
|
model: EmbeddingModel instance |
|
|
benchmark_pairs: List of benchmark pairs with 'he', 'en', 'category' keys |
|
|
batch_size: Batch size for encoding |
|
|
max_pairs: Maximum pairs to evaluate (for faster testing) |
|
|
progress_callback: Optional callback(progress_fraction, message) for progress updates |
|
|
|
|
|
Returns: |
|
|
EvaluationResults with all metrics |
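
    Example (hypothetical usage; ``model`` and ``pairs`` supplied by the caller):
        results = evaluate_model(model, pairs, max_pairs=500)
        print(results.recall_at_1, results.mrr)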
|
|
""" |
|
|
from collections import Counter |
|
|
|
|
|
|
|
|
if max_pairs and len(benchmark_pairs) > max_pairs: |
|
|
benchmark_pairs = benchmark_pairs[:max_pairs] |
|
|
|
|
|
|
|
|
he_texts = [p["he"] for p in benchmark_pairs] |
|
|
en_texts = [p["en"] for p in benchmark_pairs] |
|
|
categories = Counter(p.get("category", "Unknown") for p in benchmark_pairs) |
|
|
n_total = len(he_texts) |
|
|
n_batches = (n_total + batch_size - 1) // batch_size |
|
|
|
|
|
    # Progress budget: 45% for Hebrew/Aramaic encoding, 45% for English
    # encoding, and the final 10% for metric computation.
|
|
|
|
|
|
|
|
if progress_callback: |
|
|
progress_callback(0, f"⏳ Encoding Hebrew/Aramaic texts: 0/{n_total:,}") |
|
|
|
|
|
he_embeddings_list = [] |
|
|
for i in range(0, len(he_texts), batch_size): |
|
|
batch = he_texts[i:i + batch_size] |
|
|
batch_emb = model.encode( |
|
|
batch, |
|
|
is_query=True, |
|
|
batch_size=batch_size, |
|
|
show_progress=False, |
|
|
) |
|
|
he_embeddings_list.append(batch_emb) |
|
|
done = min(i + batch_size, len(he_texts)) |
|
|
batch_idx = (i // batch_size) + 1 |
|
|
if progress_callback: |
|
|
progress_callback(0.45 * batch_idx / n_batches, f"⏳ Encoding Hebrew/Aramaic: {done:,}/{n_total:,}") |
|
|
|
|
|
he_embeddings = np.vstack(he_embeddings_list) |
|
|
|
|
|
|
|
|
if progress_callback: |
|
|
progress_callback(0.45, f"⏳ Encoding English texts: 0/{n_total:,}") |
|
|
|
|
|
en_embeddings_list = [] |
|
|
for i in range(0, len(en_texts), batch_size): |
|
|
batch = en_texts[i:i + batch_size] |
|
|
batch_emb = model.encode( |
|
|
batch, |
|
|
is_query=False, |
|
|
batch_size=batch_size, |
|
|
show_progress=False, |
|
|
) |
|
|
en_embeddings_list.append(batch_emb) |
|
|
done = min(i + batch_size, len(en_texts)) |
|
|
batch_idx = (i // batch_size) + 1 |
|
|
if progress_callback: |
|
|
progress_callback(0.45 + 0.45 * batch_idx / n_batches, f"⏳ Encoding English: {done:,}/{n_total:,}") |
|
|
|
|
|
en_embeddings = np.vstack(en_embeddings_list) |
|
|
|
|
|
if progress_callback: |
|
|
progress_callback(0.92, "⏳ Computing similarity matrix...") |
|
|
similarity_matrix = compute_similarity_matrix(he_embeddings, en_embeddings) |
|
|
|
|
|
if progress_callback: |
|
|
progress_callback(0.95, "⏳ Computing retrieval metrics...") |
|
|
retrieval_metrics = compute_retrieval_metrics(similarity_matrix) |
|
|
|
|
|
if progress_callback: |
|
|
progress_callback(0.98, "⏳ Computing bitext accuracy...") |
|
|
bitext_acc, avg_true_sim, avg_random_sim = compute_bitext_accuracy(similarity_matrix) |
|
|
|
|
|
if progress_callback: |
|
|
progress_callback(1.0, "✅ Evaluation complete!") |
|
|
|
|
|
return EvaluationResults( |
|
|
model_id=model.model_id, |
|
|
model_name=model.name, |
|
|
recall_at_1=retrieval_metrics["recall_at_1"], |
|
|
recall_at_5=retrieval_metrics["recall_at_5"], |
|
|
recall_at_10=retrieval_metrics["recall_at_10"], |
|
|
mrr=retrieval_metrics["mrr"], |
|
|
bitext_accuracy=bitext_acc, |
|
|
avg_true_pair_similarity=avg_true_sim, |
|
|
avg_random_pair_similarity=avg_random_sim, |
|
|
num_pairs=len(benchmark_pairs), |
|
|
categories=dict(categories), |
|
|
) |
|
|
|
|
|
|
|
|
def evaluate_model_streaming( |
|
|
model, |
|
|
benchmark_pairs: list[dict], |
|
|
batch_size: int = 32, |
|
|
max_pairs: Optional[int] = None, |
|
|
): |
|
|
""" |
|
|
Run evaluation with streaming progress updates. |
|
|
|
|
|
Yields progress strings during encoding, then yields final EvaluationResults. |
|
|
|
|
|
Args: |
|
|
model: EmbeddingModel instance |
|
|
benchmark_pairs: List of benchmark pairs with 'he', 'en', 'category' keys |
|
|
batch_size: Batch size for encoding |
|
|
max_pairs: Maximum pairs to evaluate (for faster testing) |
|
|
|
|
|
Yields: |
|
|
Progress strings, then final EvaluationResults |
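
    Example (hypothetical caller; ``model`` and ``pairs`` as in evaluate_model):
        for update in evaluate_model_streaming(model, pairs):
            if isinstance(update, EvaluationResults):
                results = update
            else:
                print(update)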
|
|
""" |
|
|
from collections import Counter |
|
|
|
|
|
|
|
|
if max_pairs and len(benchmark_pairs) > max_pairs: |
|
|
benchmark_pairs = benchmark_pairs[:max_pairs] |
|
|
|
|
|
|
|
|
he_texts = [p["he"] for p in benchmark_pairs] |
|
|
en_texts = [p["en"] for p in benchmark_pairs] |
|
|
categories = Counter(p.get("category", "Unknown") for p in benchmark_pairs) |
|
|
n_total = len(he_texts) |
|
|
|
|
|
|
|
|
yield f"⏳ Encoding Hebrew/Aramaic texts: 0/{n_total:,}" |
|
|
he_embeddings_list = [] |
|
|
for i in range(0, len(he_texts), batch_size): |
|
|
batch = he_texts[i:i + batch_size] |
|
|
batch_emb = model.encode( |
|
|
batch, |
|
|
is_query=True, |
|
|
batch_size=batch_size, |
|
|
show_progress=False, |
|
|
) |
|
|
he_embeddings_list.append(batch_emb) |
|
|
done = min(i + batch_size, len(he_texts)) |
|
|
yield f"⏳ Encoding Hebrew/Aramaic texts: {done:,}/{n_total:,}" |
|
|
|
|
|
he_embeddings = np.vstack(he_embeddings_list) |
|
|
|
|
|
|
|
|
yield f"⏳ Encoding English texts: 0/{n_total:,}" |
|
|
en_embeddings_list = [] |
|
|
for i in range(0, len(en_texts), batch_size): |
|
|
batch = en_texts[i:i + batch_size] |
|
|
batch_emb = model.encode( |
|
|
batch, |
|
|
is_query=False, |
|
|
batch_size=batch_size, |
|
|
show_progress=False, |
|
|
) |
|
|
en_embeddings_list.append(batch_emb) |
|
|
done = min(i + batch_size, len(en_texts)) |
|
|
yield f"⏳ Encoding English texts: {done:,}/{n_total:,}" |
|
|
|
|
|
en_embeddings = np.vstack(en_embeddings_list) |
|
|
|
|
|
yield "⏳ Computing similarity matrix..." |
|
|
similarity_matrix = compute_similarity_matrix(he_embeddings, en_embeddings) |
|
|
|
|
|
yield "⏳ Computing retrieval metrics..." |
|
|
retrieval_metrics = compute_retrieval_metrics(similarity_matrix) |
|
|
|
|
|
yield "⏳ Computing bitext accuracy..." |
|
|
bitext_acc, avg_true_sim, avg_random_sim = compute_bitext_accuracy( |
|
|
similarity_matrix |
|
|
) |
|
|
|
|
|
|
|
|
yield EvaluationResults( |
|
|
model_id=model.model_id, |
|
|
model_name=model.name, |
|
|
recall_at_1=retrieval_metrics["recall_at_1"], |
|
|
recall_at_5=retrieval_metrics["recall_at_5"], |
|
|
recall_at_10=retrieval_metrics["recall_at_10"], |
|
|
mrr=retrieval_metrics["mrr"], |
|
|
bitext_accuracy=bitext_acc, |
|
|
avg_true_pair_similarity=avg_true_sim, |
|
|
avg_random_pair_similarity=avg_random_sim, |
|
|
num_pairs=len(benchmark_pairs), |
|
|
categories=dict(categories), |
|
|
) |
|
|
|
|
|
|
|
|
def evaluate_by_category( |
|
|
model, |
|
|
benchmark_pairs: list[dict], |
|
|
batch_size: int = 32, |
|
|
) -> dict[str, EvaluationResults]: |
|
|
""" |
|
|
Run evaluation broken down by category. |
|
|
|
|
|
Args: |
|
|
model: EmbeddingModel instance |
|
|
benchmark_pairs: List of benchmark pairs |
|
|
batch_size: Batch size for encoding |
|
|
|
|
|
Returns: |
|
|
Dict mapping category name to EvaluationResults |
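
    Example (hypothetical usage; ``model`` and ``pairs`` as in evaluate_model):
        per_category = evaluate_by_category(model, pairs)
        for name, res in per_category.items():
            print(name, res.recall_at_1)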
|
|
""" |
|
|
from collections import defaultdict |
|
|
|
|
|
|
|
|
by_category = defaultdict(list) |
|
|
for pair in benchmark_pairs: |
|
|
category = pair.get("category", "Unknown") |
|
|
by_category[category].append(pair) |
|
|
|
|
|
results = {} |
|
|
for category, pairs in by_category.items(): |
|
|
print(f"Evaluating category: {category} ({len(pairs)} pairs)") |
|
|
results[category] = evaluate_model(model, pairs, batch_size=batch_size) |
|
|
|
|
|
return results |
|
|
|
|
|
|
|
|
def get_rank_distribution( |
|
|
similarity_matrix: np.ndarray, |
|
|
    bins: tuple[int, ...] = (1, 5, 10, 50, 100),
|
|
) -> dict[str, int]: |
|
|
""" |
|
|
Get distribution of true match ranks. |
|
|
|
|
|
Args: |
|
|
similarity_matrix: (N, N) similarity matrix |
|
|
bins: Bin boundaries for histogram |
|
|
|
|
|
Returns: |
|
|
Dict mapping bin labels to counts |
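
    Example (toy doctest; identity matrix, every true match ranks first):
        >>> get_rank_distribution(np.eye(3), bins=(1, 2))
        {'Top 1': 3, '2-2': 0, '>2': 0}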
|
|
""" |
|
|
n = similarity_matrix.shape[0] |
|
|
rankings = np.argsort(-similarity_matrix, axis=1) |
|
|
|
|
|
|
|
|
true_ranks = np.zeros(n, dtype=int) |
|
|
for i in range(n): |
|
|
        # 0-based rank of the true passage i in query i's ranking.
        true_ranks[i] = np.where(rankings[i] == i)[0][0]
|
|
|
|
|
|
|
|
distribution = {} |
|
|
prev_bin = 0 |
|
|
for bin_edge in bins: |
|
|
count = np.sum((true_ranks >= prev_bin) & (true_ranks < bin_edge)) |
|
|
label = f"{prev_bin+1}-{bin_edge}" if prev_bin > 0 else f"Top {bin_edge}" |
|
|
distribution[label] = int(count) |
|
|
prev_bin = bin_edge |
|
|
|
|
|
|
|
|
remaining = np.sum(true_ranks >= bins[-1]) |
|
|
distribution[f">{bins[-1]}"] = int(remaining) |
|
|
|
|
|
return distribution |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
print("Testing evaluation functions...") |
|
|
|
|
|
|
|
|
    # Near-perfect retrieval: strong diagonal plus small off-diagonal noise.
    n = 100
    rng = np.random.default_rng(0)
    near_perfect_matrix = np.eye(n) + rng.standard_normal((n, n)) * 0.1

    metrics = compute_retrieval_metrics(near_perfect_matrix)
    print(f"Near-perfect retrieval metrics: {metrics}")
|
|
|
|
|
|
|
|
    # Random baseline: cross-similarity of two independent sets of
    # L2-normalized embeddings, so the diagonal carries no signal.
    queries = rng.standard_normal((n, 64))
    queries /= np.linalg.norm(queries, axis=1, keepdims=True)
    passages = rng.standard_normal((n, 64))
    passages /= np.linalg.norm(passages, axis=1, keepdims=True)
    random_matrix = compute_similarity_matrix(queries, passages)

    metrics = compute_retrieval_metrics(random_matrix)
    print(f"Random retrieval metrics: {metrics}")
|
|
|
|
|
|