import numpy as np
from sklearn.metrics import ndcg_score
from src.retriever import ClinicalCaseRetriever, DummyRetriever

def retrieval_metrics(retriever_instance: ClinicalCaseRetriever, queries: list[str], gold_ids: list[str], k: int = 5) -> dict | None:
    """
    Calculates retrieval metrics for a set of queries.

    Args:
        retriever_instance: An initialized ClinicalCaseRetriever instance.
        queries: A list of query strings.
        gold_ids: A list of the expected 'case_id' strings for each query.
        k: The number of top results to consider for Hit@k and NDCG@k.

    Returns:
        A dictionary containing Hit@k, MRR, and NDCG@k scores, or None on error.
    """
    # --- Initialization ---
    hits, reciprocal_ranks, ndcgs = [], [], []
    print(f"\nCalculating retrieval metrics for {len(queries)} queries (k={k})...")

    # --- Process Each Query ---
    for q_idx, (q, gold) in enumerate(zip(queries, gold_ids)):
        print(f"\nProcessing query {q_idx+1}/{len(queries)}: '{q}' (Expected ID: '{gold}')")
        try:
            retrieved_cases, scores = retriever_instance.retrieve_relevant_case(q, top_k=k, return_scores=True)
        except Exception as e:
            # Abort and signal failure, matching the "None on error" contract in the docstring.
            print(f"ERROR: retrieval failed for query '{q}': {e}")
            return None

        # Safely extract IDs, handle missing keys
        retrieved_ids = [c.get('case_id', 'N/A') for c in retrieved_cases]
        print(f"Retrieved IDs: {retrieved_ids}")
        print(f"Retrieved Scores: {[round(s, 4) for s in scores]}")

        # --- Calculate Metrics ---
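        # Hit@k: 1 if the gold case appears anywhere in the top-k results, else 0.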
        is_hit = int(gold in retrieved_ids)
        hits.append(is_hit)

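        # MRR contribution: 1 / (1-based rank of the gold case); 0 if it was not retrieved.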
        rank = 0
        if is_hit:
            rank = retrieved_ids.index(gold) + 1
            reciprocal_ranks.append(1.0 / rank)
        else:
            reciprocal_ranks.append(0.0)

        # NDCG calculation
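        # Binary relevance over the retrieved list (1 for the gold case, 0 otherwise),
        # ranked by the retriever's scores. ndcg_score expects 2D arrays of shape
        # (n_queries, n_documents), hence the extra outer dimension.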
        true_relevance = np.asarray([[1.0 if gid == gold else 0.0 for gid in retrieved_ids]])
        predicted_scores = np.asarray([scores])

        current_ndcg = 0.0
        if true_relevance.shape[1] > 1:
            ndcg_k = min(k, true_relevance.shape[1])  # Ensure k is not out of bounds
            current_ndcg = ndcg_score(true_relevance, predicted_scores, k=ndcg_k)
        elif true_relevance.shape[1] == 1:
            # sklearn's ndcg_score rejects single-document rankings; with one
            # result, NDCG reduces to that document's relevance (1.0 or 0.0).
            current_ndcg = float(true_relevance[0, 0])
        ndcgs.append(current_ndcg)

        print(f"Hit: {is_hit}, Rank: {rank if rank > 0 else 'N/A'}, NDCG@{k}: {current_ndcg:.4f}")

    # --- Aggregate Results ---
    avg_hit = np.mean(hits) if hits else 0.0
    avg_mrr = np.mean(reciprocal_ranks) if reciprocal_ranks else 0.0
    avg_ndcg = np.mean(ndcgs) if ndcgs else 0.0

    print(f"\n--- Overall Retrieval Results (k={k}) --- ")
    print(f"Average Hit@{k}: {avg_hit:.4f}")
    print(f"Average MRR:    {avg_mrr:.4f}") # Corrected spacing for alignment
    print(f"Average NDCG@{k}: {avg_ndcg:.4f}")

    return {f"Hit@{k}": avg_hit,
            f"MRR":      avg_mrr,
            f"NDCG@{k}": avg_ndcg}