# IMRNNs/src/imrnns/evaluation.py
# (uploaded by yashsaxena21 via huggingface_hub, commit 11bc1ef)
from __future__ import annotations
import math
from collections import defaultdict
from typing import List, Optional
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from .data import CachedSplit
from .model import IMRNN
try:
import faiss # type: ignore
except ImportError: # pragma: no cover
faiss = None
def _build_search_index(doc_embeddings: np.ndarray):
    """Build a FAISS inner-product index over *doc_embeddings*.

    Returns ``None`` when faiss is not installed, in which case callers
    fall back to brute-force numpy search (see ``_search``).
    """
    if faiss is not None:
        dim = doc_embeddings.shape[1]
        flat_index = faiss.IndexFlatIP(dim)
        # faiss requires float32 input.
        flat_index.add(doc_embeddings.astype("float32"))
        return flat_index
    return None
def _search(
index,
all_document_embeddings: np.ndarray,
query_embedding: np.ndarray,
k: int,
) -> tuple[np.ndarray, np.ndarray]:
if index is None:
scores = all_document_embeddings @ query_embedding
top_indices = np.argpartition(-scores, min(k, len(scores) - 1))[:k]
top_scores = scores[top_indices]
order = np.argsort(-top_scores)
return top_scores[order], top_indices[order]
query_embedding = query_embedding.reshape(1, -1).astype("float32")
scores, indices = index.search(query_embedding, k)
return scores[0], indices[0]
def _compute_metrics(ranked_doc_ids: list[str], qrel: dict[str, int], k_values: list[int]) -> dict[str, float]:
metrics: dict[str, float] = {}
for k in k_values:
top_docs = ranked_doc_ids[:k]
mrr = 0.0
for rank, doc_id in enumerate(top_docs, start=1):
if qrel.get(doc_id, 0) > 0:
mrr = 1.0 / rank
break
metrics[f"MRR@{k}"] = mrr
total_relevant = sum(1 for rel in qrel.values() if rel > 0)
retrieved_relevant = sum(1 for doc_id in top_docs if qrel.get(doc_id, 0) > 0)
metrics[f"Recall@{k}"] = retrieved_relevant / total_relevant if total_relevant else 0.0
dcg = 0.0
ideal_relevances = sorted(qrel.values(), reverse=True)[:k]
for rank, doc_id in enumerate(top_docs, start=1):
relevance = qrel.get(doc_id, 0)
if relevance > 0:
dcg += (2**relevance - 1) / math.log2(rank + 1)
idcg = 0.0
for rank, relevance in enumerate(ideal_relevances, start=1):
if relevance > 0:
idcg += (2**relevance - 1) / math.log2(rank + 1)
metrics[f"NDCG@{k}"] = dcg / idcg if idcg else 0.0
return metrics
def evaluate_model(
    model: IMRNN,
    cached_split: CachedSplit,
    device: str,
    feedback_k: int = 100,
    ranking_k: int = 100,
    k_values: Optional[List[int]] = None,
) -> dict[str, float]:
    """Evaluate *model* on a cached split and return mean retrieval metrics.

    Pipeline per query: project + L2-normalize the cached query embedding,
    retrieve ``feedback_k`` candidates by inner product over the projected
    document embeddings, rerank them with ``model.score_candidates``, then
    score the top ``ranking_k`` against the qrels. Metrics are averaged
    over all judged queries.
    """
    k_values = [10] if k_values is None else k_values
    model.eval()

    # Documents that have both corpus text and a cached embedding, in a
    # stable (sorted) order so index positions map back to ids.
    doc_ids = [
        doc_id
        for doc_id in sorted(cached_split.split.corpus.keys())
        if doc_id in cached_split.document_embeddings
    ]
    num_docs = len(doc_ids)
    doc_matrix = torch.stack(
        [cached_split.document_embeddings[doc_id].float() for doc_id in doc_ids], dim=0
    ).to(device)
    doc_projected = (
        F.normalize(model.project(doc_matrix), p=2, dim=-1).detach().cpu().numpy()
    )
    search_index = _build_search_index(doc_projected)

    qrels = cached_split.split.qrels
    per_metric: defaultdict = defaultdict(list)
    with torch.no_grad():
        for qid, query_emb in tqdm(cached_split.query_embeddings.items(), desc="evaluate", leave=False):
            if qid not in qrels:
                continue  # unjudged query — nothing to score against
            projected_query = F.normalize(
                model.project(query_emb.float().unsqueeze(0).to(device)), p=2, dim=-1
            )
            # First-stage retrieval over the projected document space.
            _, hit_idx = _search(
                index=search_index,
                all_document_embeddings=doc_projected,
                query_embedding=projected_query.squeeze(0).detach().cpu().numpy(),
                k=min(feedback_k, num_docs),
            )
            candidates = [doc_ids[i] for i in hit_idx if 0 <= i < num_docs]
            if not candidates:
                continue
            cand_matrix = torch.stack(
                [cached_split.document_embeddings[doc_id].float() for doc_id in candidates],
                dim=0,
            ).to(device)
            # Second stage: model-adapted scoring of the candidate pool.
            _, _, adapted = model.score_candidates(query_emb.float().to(device), cand_matrix)
            scored = list(zip(candidates, adapted.cpu().tolist()))
            scored.sort(key=lambda pair: pair[1], reverse=True)
            reranked = [doc_id for doc_id, _ in scored[:ranking_k]]
            for name, value in _compute_metrics(reranked, qrels[qid], k_values).items():
                per_metric[name].append(value)
    return {name: float(np.mean(values)) for name, values in per_metric.items()}