import numpy as np
from sklearn.metrics import ndcg_score

from src.retriever import ClinicalCaseRetriever, DummyRetriever


def retrieval_metrics(
    retriever_instance: ClinicalCaseRetriever,
    queries: list[str],
    gold_ids: list[str],
    k: int = 5,
) -> dict[str, float]:
    """
    Calculates retrieval metrics for a set of queries.

    Args:
        retriever_instance: An initialized ClinicalCaseRetriever instance.
        queries: A list of query strings.
        gold_ids: A list of the expected 'case_id' strings, one per query.
        k: The number of top results to consider for Hit@k and NDCG@k.

    Returns:
        A dictionary containing the average Hit@k, MRR, and NDCG@k scores.
    """
    # --- Initialization ---
    hits, reciprocal_ranks, ndcgs = [], [], []
    print(f"\nCalculating retrieval metrics for {len(queries)} queries (k={k})...")

    # --- Process Each Query ---
    for q_idx, (q, gold) in enumerate(zip(queries, gold_ids)):
        print(f"\nProcessing query {q_idx + 1}/{len(queries)}: '{q}' (Expected ID: '{gold}')")
        retrieved_cases, scores = retriever_instance.retrieve_relevant_case(
            q, top_k=k, return_scores=True
        )
        # Safely extract IDs; tolerate cases with a missing 'case_id' key
        retrieved_ids = [c.get('case_id', 'N/A') for c in retrieved_cases]
        print(f"Retrieved IDs: {retrieved_ids}")
        print(f"Retrieved Scores: {[round(s, 4) for s in scores]}")

        # --- Hit@k and MRR ---
        is_hit = int(gold in retrieved_ids)
        hits.append(is_hit)
        rank = 0
        if is_hit:
            rank = retrieved_ids.index(gold) + 1  # 1-based rank of the gold case
            reciprocal_ranks.append(1.0 / rank)
        else:
            reciprocal_ranks.append(0.0)

        # --- NDCG@k ---
        # Binary relevance: 1.0 for the gold case, 0.0 for everything else.
        true_relevance = np.asarray([[1.0 if gid == gold else 0.0 for gid in retrieved_ids]])
        predicted_scores = np.asarray([scores])
        current_ndcg = 0.0
        if true_relevance.shape[1] > 1:
            ndcg_k = min(k, true_relevance.shape[1])  # Ensure k is not out of bounds
            current_ndcg = ndcg_score(true_relevance, predicted_scores, k=ndcg_k)
        elif true_relevance.shape[1] == 1:
            # sklearn's ndcg_score rejects single-document inputs; with binary
            # relevance, NDCG over one document is just that document's relevance.
            current_ndcg = float(true_relevance[0, 0])
        ndcgs.append(current_ndcg)
        print(f"Hit: {is_hit}, Rank: {rank if rank > 0 else 'N/A'}, NDCG@{k}: {current_ndcg:.4f}")

    # --- Aggregate Results ---
    avg_hit = float(np.mean(hits)) if hits else 0.0
    avg_mrr = float(np.mean(reciprocal_ranks)) if reciprocal_ranks else 0.0
    avg_ndcg = float(np.mean(ndcgs)) if ndcgs else 0.0
    print(f"\n--- Overall Retrieval Results (k={k}) ---")
    print(f"Average Hit@{k}: {avg_hit:.4f}")
    print(f"Average MRR: {avg_mrr:.4f}")
    print(f"Average NDCG@{k}: {avg_ndcg:.4f}")

    return {f"Hit@{k}": avg_hit, "MRR": avg_mrr, f"NDCG@{k}": avg_ndcg}
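

# A minimal usage sketch, assuming DummyRetriever (imported above) takes no
# constructor arguments and exposes the same retrieve_relevant_case interface
# as ClinicalCaseRetriever. The queries and gold case IDs below are
# illustrative placeholders, not real evaluation data.
if __name__ == "__main__":
    example_queries = [
        "patient with acute chest pain and shortness of breath",
        "pediatric fever with generalized rash",
    ]
    example_gold_ids = ["case_001", "case_002"]  # hypothetical case IDs

    dummy = DummyRetriever()  # assumption: zero-argument constructor
    results = retrieval_metrics(dummy, example_queries, example_gold_ids, k=5)
    print(f"\nReturned metrics: {results}")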