File size: 821 Bytes
5a3b322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from __future__ import annotations

from itertools import zip_longest
from typing import Callable, Iterable, List, Sequence, Set


def recall_at_k(ground_truth: Set[str], preds: Sequence[str], k: int) -> float:
    """Return the fraction of ``ground_truth`` ids found in the first ``k`` predictions.

    Args:
        ground_truth: Relevant ids for a single query.
        preds: Ranked predictions; only the prefix ``preds[:k]`` is considered.
        k: Cutoff (number of leading predictions to inspect). Must be >= 0.

    Returns:
        ``hits / len(ground_truth)``, a value in ``[0.0, 1.0]``.
        Returns ``0.0`` when ``ground_truth`` is empty.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        # preds[:k] with a negative k would drop items from the END of the
        # ranking (e.g. preds[:-1]) and silently produce a bogus score.
        raise ValueError(f"k must be non-negative, got {k}")
    if not ground_truth:
        return 0.0
    # set.intersection deduplicates, so a repeated id in preds counts once.
    hits = len(ground_truth.intersection(preds[:k]))
    return hits / len(ground_truth)


def mrr_at_k(ground_truth: Set[str], preds: Sequence[str], k: int) -> float:
    """Return the reciprocal rank of the first relevant id within the top ``k``.

    This is the per-query term; averaging it over queries gives MRR@k.

    Args:
        ground_truth: Relevant ids for a single query.
        preds: Ranked predictions; ``preds[0]`` is rank 1.
        k: Cutoff (number of leading predictions to inspect). Must be >= 0.

    Returns:
        ``1 / rank`` of the first prediction in ``preds[:k]`` that is in
        ``ground_truth``. Returns ``0.0`` if there is no such prediction or
        ``ground_truth`` is empty.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        # A negative slice bound would exclude the tail of the ranking
        # rather than mean "no cutoff", yielding a misleading score.
        raise ValueError(f"k must be non-negative, got {k}")
    if not ground_truth:
        return 0.0
    for rank, pid in enumerate(preds[:k], start=1):
        if pid in ground_truth:
            return 1.0 / rank
    return 0.0


def mean_metric(
    queries: Iterable[Set[str]],
    preds_list: Iterable[Sequence[str]],
    fn: Callable[[Set[str], Sequence[str], int], float],
    k: int,
) -> float:
    """Average a per-query metric over paired ground truths and predictions.

    Args:
        queries: Ground-truth id sets, one per query.
        preds_list: Ranked predictions, one sequence per query, in the same
            order as ``queries``.
        fn: Per-query metric called as ``fn(ground_truth, preds, k)``.
        k: Cutoff forwarded unchanged to ``fn``.

    Returns:
        The arithmetic mean of ``fn`` over all pairs, or ``0.0`` when both
        inputs are empty.

    Raises:
        ValueError: If ``queries`` and ``preds_list`` yield a different number
            of items. Plain ``zip`` would silently drop the extras and average
            over the wrong subset.
    """
    missing = object()  # private sentinel: cannot collide with real items
    scores = []
    # zip_longest (rather than zip) lets us detect a length mismatch while
    # still consuming both inputs lazily (they may be generators).
    for g, p in zip_longest(queries, preds_list, fillvalue=missing):
        if g is missing or p is missing:
            raise ValueError("queries and preds_list must have the same length")
        scores.append(fn(g, p, k))
    return sum(scores) / len(scores) if scores else 0.0