File size: 7,093 Bytes
ecd70d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
#!/usr/bin/env python3
"""
Embedding quality evaluation script.

Benchmarks embedding models on retrieval effectiveness using historical solution logs
as ground truth (query → used_knowledge_ids relevance judgments).

Usage:
    python scripts/eval_embeddings.py [--model MODEL_NAME] [--samples N] [--k K]

Models compared when no --model is given (the first one listed is treated as the
baseline for the "consider switching" recommendation):
    - BAAI/bge-m3 (baseline)
    - all-MiniLM-L6-v2
    - paraphrase-multilingual-MiniLM-L12-v2
    - keepitreal/vietnamese-sbert (reported as an ERROR row if it cannot be loaded)
"""

import argparse
import hashlib
import json
import logging
import os
import sys
from collections import defaultdict
from typing import Optional

import numpy as np

# Add project root to path
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, os.path.join(project_root, 'backend'))

from app.math_wiki.storage.db import _get_conn, _ensure_tables
from app.math_wiki.storage.vectors import embed_texts, build_vector_index, VectorIndex
from app.math_wiki.schemas import WikiUnit
from app.config import get_settings

# Configure the root logger at INFO level so this script's progress messages
# ("Loading evaluation data...", "Evaluating model: ...") are printed.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def get_solution_logs(limit: int = 200) -> list[dict]:
    """Return up to *limit* recent solution logs as retrieval relevance judgments.

    Only logs whose ``used_knowledge_ids`` JSON array is non-empty are selected,
    newest first (by ``created_at``). Each item has the shape
    ``{"query": <problem_text>, "relevant": <decoded used_knowledge_ids list>}``.

    NOTE(review): row access by column name assumes ``_get_conn`` installs a
    mapping-style row factory (e.g. ``sqlite3.Row``) — confirm.
    """
    fetch_sql = """
            SELECT problem_text, used_knowledge_ids
            FROM solution_logs
            WHERE json_array_length(used_knowledge_ids) > 0
            ORDER BY created_at DESC
            LIMIT ?
            """
    with _get_conn() as conn:
        _ensure_tables(conn)
        log_rows = conn.execute(fetch_sql, (limit,)).fetchall()

    return [
        {"query": row["problem_text"], "relevant": json.loads(row["used_knowledge_ids"])}
        for row in log_rows
    ]


def get_all_units() -> list[WikiUnit]:
    """Return every wiki unit whose ``deleted`` flag is false, as ``WikiUnit`` objects.

    ``problem_ids`` is stored as JSON text and decoded here.

    NOTE(review): ``json.loads`` raises ``TypeError`` if ``problem_ids`` is NULL;
    this assumes the column is always populated — verify against the schema.
    """
    with _get_conn() as conn:
        _ensure_tables(conn)
        live_rows = conn.execute("SELECT * FROM wiki_units WHERE deleted = FALSE").fetchall()

    def to_unit(row):
        # Map DB columns one-to-one onto WikiUnit fields.
        return WikiUnit(
            id=row["id"],
            type=row["type"],
            topic=row["topic"],
            subtopic=row["subtopic"],
            content=row["content"],
            problem_ids=json.loads(row["problem_ids"]),
        )

    return list(map(to_unit, live_rows))


def _load_eval_model(model_name: str):
    """Load *model_name* and tag it with the backend used to encode with it.

    Returns a ``(kind, model)`` pair:
      * ``("bge-m3", BGEM3FlagModel(..., use_fp16=False))`` for ``"BAAI/bge-m3"``;
      * ``("st", SentenceTransformer(..., device="cpu"))`` for any other name.

    The backend library is imported lazily, so only the one actually needed has
    to be installed. Import or download failures propagate to the caller.
    """
    if model_name != "BAAI/bge-m3":
        from sentence_transformers import SentenceTransformer

        return ("st", SentenceTransformer(model_name, device="cpu"))

    from FlagEmbedding import BGEM3FlagModel

    return ("bge-m3", BGEM3FlagModel(model_name, use_fp16=False))


def _encode(model_tuple, texts, prefix="passage"):
    kind, model = model_tuple
    if kind == "bge-m3":
        prefixed = [f"{prefix}: {t}" for t in texts]
        return model.encode(prefixed, return_dense=True, return_sparse=False, return_colbert_vecs=False)["dense_vecs"]
    return model.encode(texts, convert_to_numpy=True, show_progress_bar=False)


def _l2_normalize(vectors) -> np.ndarray:
    """Return *vectors* as a C-contiguous float32 matrix whose rows have unit L2 norm.

    All-zero rows are left as zeros instead of becoming NaN.
    """
    matrix = np.asarray(vectors, dtype=np.float32)
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    return np.ascontiguousarray(matrix / norms, dtype=np.float32)


def evaluate_model(model_name: str, queries: list[dict], units: list[WikiUnit], top_k: int = 5) -> dict:
    """Evaluate an embedding model on retrieval effectiveness.

    Every unit's ``content`` and every query's ``"query"`` text are embedded,
    L2-normalised, and matched with an exact L2 faiss index. On unit vectors this
    ranks neighbours identically to cosine similarity. The ``top_k`` retrieved
    unit ids are compared with the query's ``"relevant"`` ids.

    Args:
        model_name: Identifier passed to ``_load_eval_model``.
        queries: Dicts with ``"query"`` (text) and ``"relevant"`` (iterable of
            unit ids), as produced by ``get_solution_logs``.
        units: Wiki units forming the retrieval corpus.
        top_k: Number of neighbours retrieved per query (must be >= 1).

    Returns:
        On success ``{"model", "samples", "mrr", "p@5", "k"}``. ``"p@5"`` holds
        Precision@``top_k`` (key name kept for existing callers) and ``"k"``
        records the cut-off actually used.
        On empty input, invalid ``top_k`` or a model-load failure, returns
        ``{"model", "error"}``.
        Errors raised while encoding or indexing propagate to the caller.
    """
    logger.info("Evaluating model: %s", model_name)

    # Guard before the (expensive) model load; the metric averages below would
    # otherwise divide by zero.
    if not queries or not units:
        return {"model": model_name, "error": "need at least one query and one unit to evaluate"}
    if top_k < 1:
        return {"model": model_name, "error": f"top_k must be >= 1, got {top_k}"}

    try:
        model_tuple = _load_eval_model(model_name)
    except Exception as exc:
        logger.error("Failed to load model %s: %s", model_name, exc)
        return {"model": model_name, "error": str(exc)}

    import faiss

    # Not every model guarantees unit-length output, and raw L2 distance on
    # unnormalised vectors does not match the cosine similarity these models are
    # trained for. Normalising both sides makes L2 ranking equal cosine ranking.
    unit_embeds = _l2_normalize(_encode(model_tuple, [u.content for u in units], prefix="passage"))
    query_embeds = _l2_normalize(_encode(model_tuple, [q["query"] for q in queries], prefix="query"))

    index = faiss.IndexFlatL2(unit_embeds.shape[1])
    index.add(unit_embeds)
    id_map = [u.id for u in units]

    # One batched search for all queries. faiss pads rows with -1 when the index
    # holds fewer than top_k vectors; those slots are skipped below.
    _, neighbour_rows = index.search(query_embeds, top_k)

    mrr_scores = []
    p_at_k_scores = []
    for row, query_data in zip(neighbour_rows, queries):
        retrieved_ids = [id_map[i] for i in row if i >= 0]
        # NOTE(review): assumes ids in "relevant" have the same type as WikiUnit.id
        # (e.g. both str); ids of deleted units can never be retrieved and count as misses.
        relevant = set(query_data["relevant"])

        # Precision@k: relevant hits divided by k (not by the number retrieved).
        hits = sum(1 for rid in retrieved_ids if rid in relevant)
        p_at_k_scores.append(hits / top_k)

        # Reciprocal rank of the first relevant hit; 0 when none is in the top k.
        rank = next((pos for pos, rid in enumerate(retrieved_ids, start=1) if rid in relevant), None)
        mrr_scores.append(1.0 / rank if rank else 0.0)

    return {
        "model": model_name,
        "samples": len(queries),
        "mrr": round(sum(mrr_scores) / len(mrr_scores), 4),
        # Key name kept as "p@5" for existing callers; the value is Precision@top_k.
        "p@5": round(sum(p_at_k_scores) / len(p_at_k_scores), 4),
        "k": top_k,
    }


def _print_results(results: list[dict], top_k: int) -> None:
    """Print the comparison table: one row per model, error rows show the message."""
    p_label = f"P@{top_k}"
    print("\n=== Embedding Quality Evaluation ===")
    print(f"{'Model':<45} {'MRR':>6} {p_label:>6} {'Samples':>8}")
    print("-" * 70)
    for r in results:
        if "error" in r:
            print(f"{r['model']:<45} ERROR: {r['error']}")
        else:
            # "p@5" holds Precision@k for whatever k was evaluated.
            print(f"{r['model']:<45} {r['mrr']:>6} {r['p@5']:>6} {r['samples']:>8}")


def _suggest_switch(results: list[dict], threshold: float = 0.3) -> None:
    """Recommend a model whose MRR beats the baseline (``results[0]``) by more than *threshold*.

    Prints nothing when fewer than two models were evaluated, the baseline
    failed, or no other model produced metrics.
    """
    if len(results) < 2 or "error" in results[0]:
        return
    baseline = results[0]
    candidates = [r for r in results[1:] if "error" not in r]
    if not candidates:
        return

    best = max(candidates, key=lambda r: r["mrr"])
    base_mrr = baseline["mrr"]
    gain = best["mrr"] - base_mrr
    keep_msg = f"\nNo model exceeds baseline by >{threshold:.0%}. Keep current model."

    if base_mrr > 0:
        relative = gain / base_mrr
        if relative <= threshold:
            print(keep_msg)
            return
        print(f"\n→ {best['model']} improves MRR by {relative*100:.1f}% over baseline.")
    elif gain > 0:
        # A zero baseline makes the relative gain unbounded; report absolute MRR instead.
        print(f"\n→ {best['model']} reaches MRR {best['mrr']} while the baseline scored 0.")
    else:
        print(keep_msg)
        return
    print(f"  Consider setting embedding_model_name = \"{best['model']}\" in config.")


def main():
    """Command-line entry point: evaluate one or all candidate models and report.

    Loads relevance judgments from solution logs and the wiki-unit corpus, then
    evaluates each model with ``evaluate_model`` and prints a comparison table.
    When several models were evaluated, it recommends switching if one beats the
    first (baseline) model's MRR by more than 30%.

    Exits with status 1 when there are no solution logs or fewer than 2 wiki
    units. Invalid ``--samples``/``--k`` values abort via ``argparse``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default=None, help="Single model to evaluate (default: all)")
    parser.add_argument("--samples", type=int, default=200, help="Number of query samples")
    parser.add_argument("--k", type=int, default=5, help="Top-k for metrics")
    args = parser.parse_args()
    # A negative SQL LIMIT means "no limit" in SQLite, and k < 1 leaves P@k undefined.
    if args.samples < 1:
        parser.error("--samples must be >= 1")
    if args.k < 1:
        parser.error("--k must be >= 1")

    # Load data
    logger.info("Loading evaluation data...")
    queries = get_solution_logs(limit=args.samples)
    if not queries:
        logger.error("No solution logs available. Run the system with some activity first.")
        sys.exit(1)

    units = get_all_units()
    if len(units) < 2:
        logger.error("Need at least 2 wiki units to evaluate.")
        sys.exit(1)

    logger.info("Loaded %d queries, %d units", len(queries), len(units))

    # The first entry is the baseline used by the switch recommendation.
    models_to_test = [args.model] if args.model else [
        "BAAI/bge-m3",
        "all-MiniLM-L6-v2",
        "paraphrase-multilingual-MiniLM-L12-v2",
        "keepitreal/vietnamese-sbert",
    ]

    results = []
    for model_name in models_to_test:
        # Top-level boundary: one failing model must not abort the whole comparison.
        try:
            metrics = evaluate_model(model_name, queries, units, top_k=args.k)
            results.append(metrics)
        except Exception as exc:
            logger.exception("Failed to evaluate %s: %s", model_name, exc)
            results.append({"model": model_name, "error": str(exc)})

    _print_results(results, args.k)
    _suggest_switch(results)


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()