File size: 3,864 Bytes
173f28e
 
f56dbf3
173f28e
f56dbf3
173f28e
bf74331
173f28e
 
bf74331
 
 
 
f56dbf3
 
db0da0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf74331
 
 
 
 
 
 
 
 
 
db0da0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf74331
db0da0a
bf74331
 
 
 
 
 
db0da0a
bf74331
 
 
 
db0da0a
bf74331
 
 
 
db0da0a
 
 
 
 
 
 
bf74331
 
 
 
 
7c4c461
 
bf74331
 
 
db0da0a
bf74331
 
 
 
 
 
 
 
 
 
 
 
db0da0a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from __future__ import annotations

import numpy as np
from datasets import load_dataset
from scipy.stats import spearmanr

from dataset_config import DatasetConfig


def _normalize(emb: np.ndarray) -> np.ndarray:
    """L2-normalize each row.

    Rows whose norm is exactly zero are returned unchanged (all zeros)
    rather than being divided by zero, which would fill them with NaN.

    Args:
        emb: 2-D array; the norm is taken along axis 1.

    Returns:
        Array of the same shape in which every non-zero row has unit
        L2 norm.
    """
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    # Swap zero norms for 1 so the division leaves all-zero rows untouched.
    safe_norms = np.where(norms == 0, 1.0, norms)
    return emb / safe_norms


# Catalogue of retrieval metric names listed by this module, grouped by
# family; the number after "@" is the ranking cutoff k.
ALL_RETRIEVAL_METRICS = [
    # Mean reciprocal rank (no cutoff).
    "mrr",
    # Mean average precision.
    "map@5",
    "map@10",
    # Normalized discounted cumulative gain.
    "ndcg@5",
    "ndcg@10",
    # Precision.
    "precision@1",
    "precision@5",
    "precision@10",
    # Recall.
    "recall@1",
    "recall@5",
    "recall@10",
]

# Metrics computed when the caller does not request a specific selection.
DEFAULT_RETRIEVAL_METRICS = [
    "mrr",
    "recall@1",
    "recall@5",
    "recall@10",
]


def _retrieval_metrics(
    emb_q: np.ndarray,
    emb_p: np.ndarray,
    metrics: list[str] | None = None,
) -> dict[str, float]:
    """Compute retrieval metrics assuming query i matches passage i.

    Both embedding matrices are L2-normalized and compared by dot product.
    For every query the passages are ranked by descending similarity, and
    the 0-based position of passage ``i`` in query ``i``'s ranking drives
    all metrics (each query has exactly one relevant passage).

    Args:
        emb_q: Query embeddings, one row per query.
        emb_p: Passage embeddings, one row per passage. There must be at
            least as many passages as queries.
        metrics: Metric names to compute: ``"mrr"`` or ``"<family>@<k>"``
            where family is one of ``recall``, ``precision``, ``map``,
            ``ndcg``. Defaults to ``DEFAULT_RETRIEVAL_METRICS``.

    Returns:
        Mapping from each recognized metric name to its mean over all
        queries, rounded to 4 decimal places. Unrecognized names are
        skipped with a warning.

    Raises:
        ValueError: If there are fewer passages than queries, or if the
            ``k`` part of a metric name is not an integer.
    """
    import warnings  # local import: only used for the unknown-metric notice

    if metrics is None:
        metrics = DEFAULT_RETRIEVAL_METRICS

    emb_q = _normalize(emb_q)
    emb_p = _normalize(emb_p)

    # Similarity matrix: (num_queries, num_passages)
    sims = emb_q @ emb_p.T

    n, num_passages = sims.shape
    if num_passages < n:
        # Query i needs passage i to exist; fail with a clear message
        # instead of an opaque indexing error further down.
        raise ValueError(
            f"need at least as many passages as queries: got "
            f"{num_passages} passages for {n} queries"
        )

    sorted_indices = np.argsort(-sims, axis=1)
    # Each row of sorted_indices is a permutation of the passage indices, so
    # row i contains the value i exactly once. np.nonzero reports matches in
    # row-major order, hence its column indices are the per-query ranks.
    hits = sorted_indices == np.arange(n)[:, None]
    ranks = np.nonzero(hits)[1]

    results: dict[str, float] = {}

    for m in metrics:
        if m == "mrr":
            per_query = 1.0 / (ranks + 1)
        elif m.startswith(("recall@", "precision@", "map@", "ndcg@")):
            family = m.partition("@")[0]
            k = int(m.split("@")[1])
            in_top_k = ranks < k
            if family == "recall":
                per_query = in_top_k
            elif family == "precision":
                # Single relevant doc per query: precision@k = 1/k if hit, else 0
                per_query = in_top_k / k
            elif family == "map":
                # Single relevant doc: AP = 1/(rank+1) if rank < k, else 0
                per_query = np.where(in_top_k, 1.0 / (ranks + 1), 0.0)
            else:
                # Single relevant doc: DCG = 1/log2(rank+2) if rank < k, else 0
                # ideal DCG = 1/log2(2) = 1.0, so nDCG equals DCG here.
                per_query = np.where(in_top_k, 1.0 / np.log2(ranks + 2), 0.0)
        else:
            # Keep the historical "skip" behavior, but make it visible so a
            # misspelled metric name does not vanish without a trace.
            warnings.warn(
                f"Ignoring unknown retrieval metric {m!r}",
                stacklevel=2,
            )
            continue

        results[m] = round(float(np.mean(per_query)), 4)

    return results


def evaluate_quality(
    model,
    ds_cfg: DatasetConfig | None = None,
    max_pairs: int | None = None,
    metrics: list[str] | None = None,
) -> dict[str, float]:
    """Evaluate embedding quality on a dataset.

    Returns a dict with either {"spearman": float} for scored datasets
    or selected retrieval metrics for pair datasets.

    Args:
        model: Object exposing ``encode(texts, is_query=...)``; it is called
            once for the queries and once for the passages.
        ds_cfg: Dataset description. A default ``DatasetConfig()`` is used
            when omitted. If ``ds_cfg.data`` is set it is used directly,
            otherwise the dataset is loaded via ``load_dataset``.
        max_pairs: If given, evaluate only the first ``max_pairs`` rows.
        metrics: Retrieval metric names forwarded to ``_retrieval_metrics``
            in pair mode; not used in scored mode.

    Returns:
        ``{"spearman": value}`` (rounded to 4 decimals) when
        ``ds_cfg.score_col`` is set, otherwise the retrieval metrics dict.
    """
    if ds_cfg is None:
        ds_cfg = DatasetConfig()

    if ds_cfg.data is not None:
        data = ds_cfg.data
    else:
        dataset = load_dataset(ds_cfg.name, ds_cfg.config, split=ds_cfg.split)
        # Materialize only the columns that are actually read below instead
        # of copying every column of the dataset into Python lists.
        needed_cols = [ds_cfg.query_col, ds_cfg.passage_col]
        if ds_cfg.score_col is not None:
            needed_cols.append(ds_cfg.score_col)
        data = {col: list(dataset[col]) for col in needed_cols}
    queries = list(data[ds_cfg.query_col])
    passages = list(data[ds_cfg.passage_col])

    if max_pairs is not None and len(queries) > max_pairs:
        queries = queries[:max_pairs]
        passages = passages[:max_pairs]

    emb_q = model.encode(queries, is_query=True)
    emb_p = model.encode(passages, is_query=False)

    if ds_cfg.score_col is not None:
        # Scored mode: Spearman correlation
        scores = list(data[ds_cfg.score_col])
        if max_pairs is not None and len(scores) > max_pairs:
            scores = scores[:max_pairs]
        gold_scores = [s / ds_cfg.score_scale for s in scores]

        dots = np.sum(emb_q * emb_p, axis=1)
        norm_product = np.linalg.norm(emb_q, axis=1) * np.linalg.norm(emb_p, axis=1)
        # A zero-norm embedding would give 0/0 = NaN and turn the whole
        # correlation into NaN; divide by 1 instead so that pair scores 0.
        cos_sims = dots / np.where(norm_product == 0, 1.0, norm_product)

        correlation, _ = spearmanr(cos_sims, gold_scores)
        return {"spearman": round(float(correlation), 4)}

    # Pair mode: retrieval metrics
    return _retrieval_metrics(emb_q, emb_p, metrics=metrics)