import numpy as np
from sklearn.metrics import ndcg_score

from src.retriever import ClinicalCaseRetriever, DummyRetriever


def retrieval_metrics(
    retriever_instance: ClinicalCaseRetriever,
    queries: list[str],
    gold_ids: list[str],
    k: int = 5,
) -> dict:
    """
    Calculate retrieval metrics for a set of queries.

    Args:
        retriever_instance: An initialized ClinicalCaseRetriever instance.
        queries: A list of query strings.
        gold_ids: The expected 'case_id' string for each query, aligned with
            ``queries``.
        k: The number of top results to consider for Hit@k and NDCG@k.

    Returns:
        A dictionary containing the average Hit@k, MRR, and NDCG@k scores.

    A usage sketch with a dummy retriever appears at the bottom of this module.
    """
    # --- Initialization ---
    hits, reciprocal_ranks, ndcgs = [], [], []
    print(f"\nCalculating retrieval metrics for {len(queries)} queries (k={k})...")

    # --- Process Each Query ---
    for q_idx, (q, gold) in enumerate(zip(queries, gold_ids)):
        print(f"\nProcessing query {q_idx + 1}/{len(queries)}: '{q}' (Expected ID: '{gold}')")
        retrieved_cases, scores = retriever_instance.retrieve_relevant_case(
            q, top_k=k, return_scores=True
        )

        # Safely extract IDs; tolerate cases with a missing 'case_id' key.
        retrieved_ids = [c.get('case_id', 'N/A') for c in retrieved_cases]
        print(f"Retrieved IDs: {retrieved_ids}")
        print(f"Retrieved Scores: {[round(s, 4) for s in scores]}")

        # --- Calculate Metrics ---
        # Hit@k: 1 if the gold case appears anywhere in the top-k results.
        is_hit = int(gold in retrieved_ids)
        hits.append(is_hit)

        # Reciprocal rank: 1/rank of the gold case, 0.0 if it was not retrieved.
        rank = 0
        if is_hit:
            rank = retrieved_ids.index(gold) + 1
            reciprocal_ranks.append(1.0 / rank)
        else:
            reciprocal_ranks.append(0.0)

        # NDCG@k: binary relevance (1 for the gold case, 0 otherwise) scored
        # against the retriever's ranking. sklearn's ndcg_score raises a
        # ValueError for single-document inputs, so require at least two.
        true_relevance = np.asarray([[1.0 if gid == gold else 0.0 for gid in retrieved_ids]])
        predicted_scores = np.asarray([scores])
        current_ndcg = 0.0
        if true_relevance.shape[1] > 1:
            ndcg_k = min(k, true_relevance.shape[1])  # Clamp k to the number retrieved.
            current_ndcg = ndcg_score(true_relevance, predicted_scores, k=ndcg_k)
        ndcgs.append(current_ndcg)
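        # Worked example (illustrative numbers, not from a real run): if the
        # gold case lands at rank 2 of k=5 results, DCG@5 = 1/log2(2 + 1) ≈ 0.631
        # and the ideal DCG is 1/log2(1 + 1) = 1.0, so NDCG@5 ≈ 0.631.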
print(f"Hit: {is_hit}, Rank: {rank if rank > 0 else 'N/A'}, NDCG@{k}: {current_ndcg:.4f}")
    # --- Aggregate Results ---
    avg_hit = np.mean(hits) if hits else 0.0
    avg_mrr = np.mean(reciprocal_ranks) if reciprocal_ranks else 0.0
    avg_ndcg = np.mean(ndcgs) if ndcgs else 0.0
    print(f"\n--- Overall Retrieval Results (k={k}) ---")
    print(f"Average Hit@{k}:  {avg_hit:.4f}")
    print(f"Average MRR:      {avg_mrr:.4f}")
    print(f"Average NDCG@{k}: {avg_ndcg:.4f}")
    return {f"Hit@{k}": avg_hit,
            "MRR": avg_mrr,
            f"NDCG@{k}": avg_ndcg}