File size: 1,733 Bytes
6d882b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from typing import List, Optional


def calculate_hit_rate(retrieved_ids: List[str], relevant_id: str) -> int:
    """
    Hit rate: report whether the relevant document shows up among the retrieved results.

    Args:
        retrieved_ids: List of retrieved document IDs.
        relevant_id: The ID of the correct/relevant document.

    Returns:
        1 when relevant_id is one of retrieved_ids, otherwise 0.
    """
    if relevant_id in retrieved_ids:
        return 1
    return 0


def calculate_mrr(retrieved_ids: List[str], relevant_id: str) -> float:
    """
    Reciprocal rank of the relevant document within one list of retrieved results.

    The rank is the 1-based position of the first occurrence of relevant_id,
    so a relevant document nearer the top of the list yields a larger score
    (max = 1.0 when it is the first result).

    Args:
        retrieved_ids: List of retrieved document IDs.
        relevant_id: The ID of the correct/relevant document.

    Returns:
        1/rank if relevant_id is found, 0.0 if it is not.
    """
    # Guard clause for the miss case, so the lookup below cannot fail.
    if relevant_id not in retrieved_ids:
        return 0.0

    # index() is 0-based and returns the first occurrence; ranks are 1-based.
    position = retrieved_ids.index(relevant_id) + 1
    return 1.0 / position


def calculate_precision_at_k(
    retrieved_ids: List[str], relevant_ids: List[str], k: Optional[int] = None
) -> float:
    """
    Precision@K: The fraction of the top-K retrieved results that are relevant.

    Each distinct relevant ID is counted as at most one hit, even if it is
    repeated in the retrieved list; the denominator is the number of retrieved
    entries kept after the cut-off (which may be fewer than k).

    Args:
        retrieved_ids: List of retrieved document IDs, best result first.
        relevant_ids: List of relevant document IDs (can be multiple).
        k: Cut-off point (if None, use all retrieved_ids). Must not be negative.

    Returns:
        Precision score (0.0 - 1.0); 0.0 when no results remain after the cut-off.

    Raises:
        ValueError: If k is negative.
    """
    if k is None:
        top_k = retrieved_ids
    else:
        # A negative k would be interpreted by slicing as "drop the last |k|
        # items", silently producing a score for the wrong cut-off.
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")
        top_k = retrieved_ids[:k]

    # Nothing retrieved (or k == 0): avoid dividing by zero.
    if len(top_k) == 0:
        return 0.0

    hits = len(set(relevant_ids).intersection(top_k))
    return hits / len(top_k)