| from __future__ import annotations |
|
| import math |
| from collections import defaultdict |
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from tqdm import tqdm |
|
| from .data import CachedSplit |
| from .model import IMRNN |
|
| try: |
| import faiss |
| except ImportError: |
| faiss = None |
|
|
| def _build_search_index(doc_embeddings: np.ndarray): |
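    """Build an exact inner-product FAISS index over the documents, or None if faiss is unavailable."""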
| if faiss is None: |
| return None |
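    # IndexFlatIP does exact inner-product search; on L2-normalized embeddings
    # this is equivalent to cosine similarity.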
| index = faiss.IndexFlatIP(doc_embeddings.shape[1]) |
| index.add(doc_embeddings.astype("float32")) |
| return index |
|
|
| def _search( |
| index, |
| all_document_embeddings: np.ndarray, |
| query_embedding: np.ndarray, |
| k: int, |
| ) -> tuple[np.ndarray, np.ndarray]: |
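    """Return the top-k (scores, indices) for a query, via FAISS or a NumPy brute-force fallback."""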
| if index is None: |
| scores = all_document_embeddings @ query_embedding |
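        # argpartition selects the k largest scores in O(n) without a full sort;
        # the pivot is clamped so tiny corpora do not raise.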
| top_indices = np.argpartition(-scores, min(k, len(scores) - 1))[:k] |
| top_scores = scores[top_indices] |
| order = np.argsort(-top_scores) |
| return top_scores[order], top_indices[order] |
| query_embedding = query_embedding.reshape(1, -1).astype("float32") |
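    # FAISS pads results with -1 indices when k exceeds the index size; callers filter them out.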
| scores, indices = index.search(query_embedding, k) |
| return scores[0], indices[0] |
|
|
def _compute_metrics(
    ranked_doc_ids: list[str],
    qrel: dict[str, int],
    k_values: list[int],
) -> dict[str, float]:
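    """Compute MRR@k, Recall@k, and NDCG@k for one ranked list against its qrels."""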
| metrics: dict[str, float] = {} |
| for k in k_values: |
| top_docs = ranked_doc_ids[:k] |
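        # MRR@k: reciprocal rank of the first relevant document within the top k.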
| mrr = 0.0 |
| for rank, doc_id in enumerate(top_docs, start=1): |
| if qrel.get(doc_id, 0) > 0: |
| mrr = 1.0 / rank |
| break |
| metrics[f"MRR@{k}"] = mrr |
|
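        # Recall@k: fraction of all relevant documents that appear in the top k.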
| total_relevant = sum(1 for rel in qrel.values() if rel > 0) |
| retrieved_relevant = sum(1 for doc_id in top_docs if qrel.get(doc_id, 0) > 0) |
| metrics[f"Recall@{k}"] = retrieved_relevant / total_relevant if total_relevant else 0.0 |
|
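        # NDCG@k with exponential gain (2**rel - 1), normalized by the ideal DCG
        # over the k highest relevance grades in the qrels.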
| dcg = 0.0 |
| ideal_relevances = sorted(qrel.values(), reverse=True)[:k] |
| for rank, doc_id in enumerate(top_docs, start=1): |
| relevance = qrel.get(doc_id, 0) |
| if relevance > 0: |
| dcg += (2**relevance - 1) / math.log2(rank + 1) |
| idcg = 0.0 |
| for rank, relevance in enumerate(ideal_relevances, start=1): |
| if relevance > 0: |
| idcg += (2**relevance - 1) / math.log2(rank + 1) |
| metrics[f"NDCG@{k}"] = dcg / idcg if idcg else 0.0 |
| return metrics |
|
|
| def evaluate_model( |
| model: IMRNN, |
| cached_split: CachedSplit, |
| device: str, |
| feedback_k: int = 100, |
| ranking_k: int = 100, |
    k_values: list[int] | None = None,
| ) -> dict[str, float]: |
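    """Evaluate the model with a retrieve-then-rerank pipeline and return macro-averaged metrics.

    Stage 1 retrieves up to ``feedback_k`` candidates per query by dense search over
    the projected corpus; stage 2 reranks them with ``model.score_candidates`` and
    keeps the top ``ranking_k``. Queries without qrels are skipped.

    Usage sketch (assumes a trained ``IMRNN`` and a prepared ``CachedSplit``)::

        metrics = evaluate_model(model, cached_split, device="cuda", k_values=[10, 100])
        print(metrics["NDCG@10"])
    """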
| if k_values is None: |
| k_values = [10] |
|
| model.eval() |
|
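    # Project and L2-normalize the whole corpus once; every query is searched
    # against this fixed matrix.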
| document_ids = sorted( |
| doc_id for doc_id in cached_split.split.corpus.keys() if doc_id in cached_split.document_embeddings |
| ) |
| document_tensor = torch.stack( |
| [cached_split.document_embeddings[doc_id].float() for doc_id in document_ids], dim=0 |
| ).to(device) |
    # No autograd graph is needed when projecting the static document matrix.
    with torch.no_grad():
        projected_documents = (
            F.normalize(model.project(document_tensor), p=2, dim=-1).cpu().numpy()
        )
| index = _build_search_index(projected_documents) |
|
| aggregated = defaultdict(list) |
|
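    # Retrieve-then-rerank: dense first-stage retrieval, then model-based reranking.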
| with torch.no_grad(): |
| for qid, query_embedding in tqdm(cached_split.query_embeddings.items(), desc="evaluate", leave=False): |
| if qid not in cached_split.split.qrels: |
| continue |
|
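            # Stage 1: retrieve feedback_k candidates with the projected base query.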
| base_query = F.normalize(model.project(query_embedding.float().unsqueeze(0).to(device)), p=2, dim=-1) |
| scores, indices = _search( |
| index=index, |
| all_document_embeddings=projected_documents, |
| query_embedding=base_query.squeeze(0).detach().cpu().numpy(), |
| k=min(feedback_k, len(document_ids)), |
| ) |
|
| candidate_ids = [document_ids[idx] for idx in indices if 0 <= idx < len(document_ids)] |
| if not candidate_ids: |
| continue |
|
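            # Gather the raw (unprojected) embeddings of the retrieved candidates
            # for the reranker.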
| candidate_embeddings = torch.stack( |
| [cached_split.document_embeddings[doc_id].float() for doc_id in candidate_ids], |
| dim=0, |
| ).to(device) |
|
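            # Stage 2: rerank; score_candidates returns a triple whose last element
            # is the adapted relevance score per candidate.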
| _, _, adapted_scores = model.score_candidates(query_embedding.float().to(device), candidate_embeddings) |
| adapted_scores = adapted_scores.cpu().tolist() |
| reranked = [ |
| doc_id for doc_id, _ in sorted(zip(candidate_ids, adapted_scores), key=lambda item: item[1], reverse=True) |
| ][:ranking_k] |
|
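            # Accumulate per-query metrics for macro-averaging across queries.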
| metrics = _compute_metrics(reranked, cached_split.split.qrels[qid], k_values) |
| for name, value in metrics.items(): |
| aggregated[name].append(value) |
|
| return {metric: float(np.mean(values)) for metric, values in aggregated.items()} |