#!/usr/bin/env python3
"""
Embedding quality evaluation script.
Benchmarks embedding models on retrieval effectiveness using historical solution logs
as ground truth (query → used_knowledge_ids relevance judgments).
Usage:
python scripts/eval_embeddings.py [--model MODEL_NAME] [--samples N]
Models to compare (if no --model specified):
- all-MiniLM-L6-v2 (baseline)
- paraphrase-multilingual-MiniLM-L12-v2
- sentence-transformers/msmarco-MiniLM-L6-en
- keepitreal/vietnamese-sbert (if available)
"""
import argparse
import json
import logging
import os
import sys

import numpy as np
# Add the backend package to sys.path so `app.*` imports resolve when this
# script is run from anywhere in the repo.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, os.path.join(project_root, 'backend'))

from app.math_wiki.storage.db import _get_conn, _ensure_tables
from app.math_wiki.schemas import WikiUnit
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def get_solution_logs(limit: int = 200) -> list[dict]:
    """Fetch recent solution logs with used_knowledge_ids for relevance judgments."""
    with _get_conn() as conn:
        _ensure_tables(conn)
        rows = conn.execute(
            """
            SELECT problem_text, used_knowledge_ids
            FROM solution_logs
            WHERE json_array_length(used_knowledge_ids) > 0
            ORDER BY created_at DESC
            LIMIT ?
            """,
            (limit,),
        ).fetchall()
    return [
        {"query": r["problem_text"], "relevant": json.loads(r["used_knowledge_ids"])}
        for r in rows
    ]
def get_all_units() -> list[WikiUnit]:
    """Load all non-deleted wiki units from the DB."""
    with _get_conn() as conn:
        _ensure_tables(conn)
        rows = conn.execute("SELECT * FROM wiki_units WHERE deleted = FALSE").fetchall()
    return [
        WikiUnit(
            id=r["id"],
            type=r["type"],
            topic=r["topic"],
            subtopic=r["subtopic"],
            content=r["content"],
            problem_ids=json.loads(r["problem_ids"]),
        )
        for r in rows
    ]
def _load_eval_model(model_name: str):
    """Load a model behind a (kind, model) tuple so _encode can dispatch on its API."""
    if model_name == "BAAI/bge-m3":
        # bge-m3 ships with its own wrapper rather than sentence-transformers.
        from FlagEmbedding import BGEM3FlagModel

        return ("bge-m3", BGEM3FlagModel(model_name, use_fp16=False))
    from sentence_transformers import SentenceTransformer

    return ("st", SentenceTransformer(model_name, device="cpu"))
def _encode(model_tuple, texts, prefix: str = "passage"):
    """Encode texts into a dense numpy matrix, dispatching on the model kind."""
    kind, model = model_tuple
    if kind == "bge-m3":
        # E5-style "query: " / "passage: " prefixes; bge-m3 does not strictly
        # require them, but they are applied consistently on both sides here.
        prefixed = [f"{prefix}: {t}" for t in texts]
        return model.encode(
            prefixed, return_dense=True, return_sparse=False, return_colbert_vecs=False
        )["dense_vecs"]
    return model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
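
# Usage sketch for the two helpers above (the model name is real, the shape
# illustrative; all-MiniLM-L6-v2 produces 384-dimensional vectors):
#   model_tuple = _load_eval_model("all-MiniLM-L6-v2")
#   vecs = _encode(model_tuple, ["Prove that sqrt(2) is irrational."], prefix="query")
#   vecs.shape  # (1, 384)
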
def evaluate_model(model_name: str, queries: list[dict], units: list[WikiUnit], top_k: int = 5) -> dict:
    """Evaluate an embedding model on retrieval effectiveness (MRR and P@k)."""
    logger.info("Evaluating model: %s", model_name)
    try:
        model_tuple = _load_eval_model(model_name)
    except Exception as exc:
        logger.error("Failed to load model %s: %s", model_name, exc)
        return {"model": model_name, "error": str(exc)}

    unit_texts = [u.content for u in units]
    unit_embeds = _encode(model_tuple, unit_texts, prefix="passage")
    dim = unit_embeds.shape[1]

    import faiss  # deferred: only needed once a model has loaded successfully

    index = faiss.IndexFlatL2(dim)
    index.add(unit_embeds.astype(np.float32))
    id_map = [u.id for u in units]

    mrr_scores = []
    p_at_k_scores = []
    query_embeds = _encode(model_tuple, [q["query"] for q in queries], prefix="query")
    for q_vec, query_data in zip(query_embeds, queries):
        q_vec_np = np.array([q_vec], dtype=np.float32)
        _, indices = index.search(q_vec_np, top_k)
        retrieved_ids = [id_map[i] for i in indices[0] if i >= 0]
        relevant = set(query_data["relevant"])
        # Precision@k: fraction of the top-k results that are relevant.
        hits = [rid for rid in retrieved_ids if rid in relevant]
        p_at_k_scores.append(len(hits) / top_k)
        # MRR: reciprocal rank of the first relevant result (0 if none retrieved).
        rank = next((i + 1 for i, rid in enumerate(retrieved_ids) if rid in relevant), None)
        mrr_scores.append(1.0 / rank if rank else 0.0)
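    # Worked example (hypothetical ids): with top_k = 5, retrieved =
    # ["u1", "u2", "u3", "u4", "u5"] and relevant = {"u3"}, P@5 = 1/5 = 0.2;
    # the first hit sits at rank 3, so its reciprocal rank is 1/3 ≈ 0.333.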
    return {
        "model": model_name,
        "samples": len(queries),
        "k": top_k,
        "mrr": round(sum(mrr_scores) / len(mrr_scores), 4),
        # Keyed as p_at_k so the key stays correct when --k != 5.
        "p_at_k": round(sum(p_at_k_scores) / len(p_at_k_scores), 4),
    }
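
# Note: IndexFlatL2 above ranks by Euclidean distance on the raw embeddings.
# If cosine ranking is wanted instead (common for sentence-transformers
# models), a minimal variant is to L2-normalize in place and use inner product:
#   faiss.normalize_L2(unit_embeds)  # expects a contiguous float32 array
#   index = faiss.IndexFlatIP(dim)
# The L2 index is kept here to preserve the script's original behavior.
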
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default=None, help="Single model to evaluate (default: compare all)")
    parser.add_argument("--samples", type=int, default=200, help="Number of query samples")
    parser.add_argument("--k", type=int, default=5, help="Top-k for metrics")
    args = parser.parse_args()

    # Load data
    logger.info("Loading evaluation data...")
    queries = get_solution_logs(limit=args.samples)
    if not queries:
        logger.error("No solution logs available. Run the system with some activity first.")
        sys.exit(1)
    units = get_all_units()
    if len(units) < 2:
        logger.error("Need at least 2 wiki units to evaluate.")
        sys.exit(1)
    logger.info("Loaded %d queries, %d units", len(queries), len(units))
    # Baseline first: the suggest-switch logic below compares against results[0].
    models_to_test = [args.model] if args.model else [
        "all-MiniLM-L6-v2",  # baseline
        "BAAI/bge-m3",
        "paraphrase-multilingual-MiniLM-L12-v2",
        "keepitreal/vietnamese-sbert",
    ]
    results = []
    for model_name in models_to_test:
        try:
            metrics = evaluate_model(model_name, queries, units, top_k=args.k)
            results.append(metrics)
        except Exception as exc:
            logger.exception("Failed to evaluate %s: %s", model_name, exc)
            results.append({"model": model_name, "error": str(exc)})
    # Print comparison table
    print("\n=== Embedding Quality Evaluation ===")
    p_label = f"P@{args.k}"
    print(f"{'Model':<45} {'MRR':>6} {p_label:>6} {'Samples':>8}")
    print("-" * 70)
    for r in results:
        if "error" in r:
            print(f"{r['model']:<45} ERROR: {r['error']}")
        else:
            print(f"{r['model']:<45} {r['mrr']:>6} {r['p_at_k']:>6} {r['samples']:>8}")
    # Suggest a switch if the best model beats the baseline MRR by more than 30%.
    valid = [r for r in results if "error" not in r]
    if len(valid) >= 2 and "error" not in results[0]:
        baseline = results[0]
        best = max(valid, key=lambda x: x.get("mrr", 0))
        if best is not baseline:
            improvement = (best["mrr"] - baseline["mrr"]) / baseline["mrr"] if baseline["mrr"] > 0 else 0
            if improvement > 0.3:
                print(f"\n→ {best['model']} improves MRR by {improvement * 100:.1f}% over baseline.")
                print(f"  Consider setting embedding_model_name = \"{best['model']}\" in config.")
            else:
                print("\nNo model exceeds baseline by >30%. Keep current model.")
if __name__ == "__main__":
    main()