Spaces:
Running
Running
| from __future__ import annotations | |
| import numpy as np | |
| from datasets import load_dataset | |
| from scipy.stats import spearmanr | |
| from dataset_config import DatasetConfig | |
def _normalize(emb: np.ndarray) -> np.ndarray:
    """L2-normalize each row of a 2-D embedding matrix.

    Args:
        emb: Array indexed as (row, feature); the norm is taken along axis 1.

    Returns:
        An array of the same shape in which every row with a non-zero norm
        has unit L2 length. All-zero rows are returned unchanged (still all
        zeros) instead of becoming NaN.
    """
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    # A zero vector has no direction; dividing by its zero norm would produce
    # NaN (and a RuntimeWarning) that then poisons every similarity computed
    # from it. `norms` is a fresh array owned by this function, so patching it
    # in place is safe and keeps its dtype.
    norms[norms == 0] = 1.0
    return emb / norms
# Names of the retrieval metrics this module knows about. Apart from "mrr",
# each name has the form "<family>@<k>", where k is the rank cutoff.
ALL_RETRIEVAL_METRICS = [
    "mrr",
    "map@5",
    "map@10",
    "ndcg@5",
    "ndcg@10",
    "precision@1",
    "precision@5",
    "precision@10",
    "recall@1",
    "recall@5",
    "recall@10",
]

# Metrics computed when the caller does not pass an explicit metric list.
DEFAULT_RETRIEVAL_METRICS = [
    "mrr",
    "recall@1",
    "recall@5",
    "recall@10",
]
def _retrieval_metrics(
    emb_q: np.ndarray,
    emb_p: np.ndarray,
    metrics: list[str] | None = None,
) -> dict[str, float]:
    """Compute retrieval metrics assuming query i matches passage i.

    Both embedding matrices are L2-normalized, every query is scored against
    every passage by dot product, and the 0-based rank of passage i in the
    descending-similarity ordering for query i drives all metrics. Because
    each query has exactly one relevant passage, the metrics reduce to:

    * ``mrr``          -- mean of 1 / (rank + 1)
    * ``recall@k``     -- fraction of queries with rank < k
    * ``precision@k``  -- mean of (1 / k if rank < k else 0)
    * ``map@k``        -- mean of (1 / (rank + 1) if rank < k else 0)
    * ``ndcg@k``       -- mean of (1 / log2(rank + 2) if rank < k else 0);
      the ideal DCG is 1 / log2(2) = 1, so no further normalization is needed

    Args:
        emb_q: Query embeddings, one row per query.
        emb_p: Passage embeddings, one row per passage. It may contain more
            rows than ``emb_q``; the extra rows only act as distractors.
        metrics: Metric names to compute. ``None`` selects
            ``DEFAULT_RETRIEVAL_METRICS``. Names that are neither ``"mrr"``
            nor one of the ``<family>@<k>`` forms above are skipped.

    Returns:
        Mapping from each recognized metric name to its value, averaged over
        queries and rounded to 4 decimals.

    Raises:
        ValueError: If there are more queries than passages (some query would
            have no matching passage), or if the cutoff after ``@`` is not an
            integer.
    """
    if metrics is None:
        metrics = DEFAULT_RETRIEVAL_METRICS

    emb_q = _normalize(emb_q)
    emb_p = _normalize(emb_p)

    num_queries, num_passages = emb_q.shape[0], emb_p.shape[0]
    if num_queries > num_passages:
        # Query i is scored against passage i, so every query index must also
        # be a valid passage index. Fail with a clear message here rather
        # than deep inside the rank lookup.
        raise ValueError(
            f"got {num_queries} queries but only {num_passages} passages; "
            "query i must have a matching passage i"
        )

    # Similarity matrix: (num_queries, num_passages)
    sims = emb_q @ emb_p.T
    n = sims.shape[0]
    sorted_indices = np.argsort(-sims, axis=1)

    # Each row of sorted_indices is a permutation of the passage indices, so
    # passage i occurs exactly once in row i. np.nonzero walks the mask in
    # row-major order, hence its column component is the rank of the gold
    # passage for query 0, 1, ..., n-1 in that order.
    gold_mask = sorted_indices == np.arange(n)[:, None]
    ranks = np.nonzero(gold_mask)[1]

    results: dict[str, float] = {}
    for m in metrics:
        if m == "mrr":
            per_query = 1.0 / (ranks + 1)
        elif m.startswith(("recall@", "precision@", "map@", "ndcg@")):
            family, cutoff = m.split("@")[:2]
            k = int(cutoff)
            hit = ranks < k
            if family == "recall":
                per_query = hit
            elif family == "precision":
                # Single relevant doc per query: precision@k = 1/k if hit, else 0
                per_query = hit / k
            elif family == "map":
                # Single relevant doc: AP = 1/(rank+1) if rank < k, else 0
                per_query = np.where(hit, 1.0 / (ranks + 1), 0.0)
            else:
                # Single relevant doc: DCG = 1/log2(rank+2) if rank < k, else 0
                per_query = np.where(hit, 1.0 / np.log2(ranks + 2), 0.0)
        else:
            # Unrecognized names are left out of the result.
            continue
        results[m] = round(float(np.mean(per_query)), 4)
    return results
def evaluate_quality(
    model,
    ds_cfg: DatasetConfig | None = None,
    max_pairs: int | None = None,
    metrics: list[str] | None = None,
) -> dict[str, float]:
    """Evaluate embedding quality on a dataset.

    Returns a dict with either {"spearman": float} for scored datasets
    or selected retrieval metrics for pair datasets.

    Args:
        model: Object exposing ``encode(texts, is_query=...)``; it is called
            once for the queries and once for the passages. The results are
            used as 2-D NumPy arrays with one row per text.
        ds_cfg: Dataset description. ``None`` uses a default
            ``DatasetConfig()``. When ``ds_cfg.data`` is set it is used
            directly as a column mapping; otherwise the dataset is loaded
            with ``load_dataset(ds_cfg.name, ds_cfg.config, split=ds_cfg.split)``.
        max_pairs: If given, keep only the first ``max_pairs`` rows.
        metrics: Retrieval metric names for pair mode (ignored in scored
            mode); ``None`` lets the retrieval scorer pick its defaults.

    Returns:
        ``{"spearman": value}`` when ``ds_cfg.score_col`` is set, where the
        value is the Spearman correlation between per-pair cosine similarity
        and the gold scores divided by ``ds_cfg.score_scale``, rounded to
        4 decimals. Otherwise the retrieval-metric dict for the pairs.
    """
    if ds_cfg is None:
        ds_cfg = DatasetConfig()

    if ds_cfg.data is not None:
        data = ds_cfg.data
    else:
        dataset = load_dataset(ds_cfg.name, ds_cfg.config, split=ds_cfg.split)
        data = {col: list(dataset[col]) for col in dataset.column_names}

    queries = list(data[ds_cfg.query_col])
    passages = list(data[ds_cfg.passage_col])
    if max_pairs is not None and len(queries) > max_pairs:
        queries = queries[:max_pairs]
        passages = passages[:max_pairs]

    # Read the gold scores before encoding, so a missing or misnamed score
    # column fails immediately rather than after two encoding passes.
    gold_scores = None
    if ds_cfg.score_col is not None:
        scores = list(data[ds_cfg.score_col])
        if max_pairs is not None and len(scores) > max_pairs:
            scores = scores[:max_pairs]
        gold_scores = [s / ds_cfg.score_scale for s in scores]

    emb_q = model.encode(queries, is_query=True)
    emb_p = model.encode(passages, is_query=False)

    if gold_scores is not None:
        # Scored mode: Spearman correlation
        dots = np.sum(emb_q * emb_p, axis=1)
        denom = np.linalg.norm(emb_q, axis=1) * np.linalg.norm(emb_p, axis=1)
        # An all-zero embedding gives a zero denominator. Left alone, that one
        # pair becomes NaN and spearmanr then reports NaN for the whole
        # dataset. Its dot product is already 0, so dividing by 1 instead
        # yields a cosine similarity of 0 for that pair.
        denom[denom == 0] = 1.0
        cos_sims = dots / denom
        correlation, _ = spearmanr(cos_sims, gold_scores)
        return {"spearman": round(float(correlation), 4)}

    # Pair mode: retrieval metrics
    return _retrieval_metrics(emb_q, emb_p, metrics=metrics)