from __future__ import annotations

import itertools
from typing import Callable, Iterable, List, Sequence, Set

# Signature shared by the per-query metrics below: (ground_truth, preds, k) -> score.
_MetricFn = Callable[[Set[str], Sequence[str], int], float]


def _top_k(preds: Sequence[str], k: int) -> Sequence[str]:
    """Return the first ``k`` predictions, treating a negative ``k`` as 0.

    A plain ``preds[:k]`` with a negative ``k`` slices from the *end*
    (``preds[:-1]`` keeps everything but the last item), which would silently
    score almost the whole ranking instead of none of it.
    """
    return preds[: max(k, 0)]


def recall_at_k(ground_truth: Set[str], preds: Sequence[str], k: int) -> float:
    """Return the fraction of ``ground_truth`` items found in the top-``k`` predictions.

    Args:
        ground_truth: Relevant item ids for one query.
        preds: Ranked predicted ids, best first. Duplicates count once.
        k: Cutoff rank; ``k <= 0`` means no predictions are considered.

    Returns:
        ``len(ground_truth & set(preds[:k])) / len(ground_truth)``, or 0.0
        when ``ground_truth`` is empty (recall is undefined there).
    """
    if not ground_truth:
        return 0.0
    hits = len(ground_truth.intersection(_top_k(preds, k)))
    return hits / len(ground_truth)


def mrr_at_k(ground_truth: Set[str], preds: Sequence[str], k: int) -> float:
    """Return the reciprocal rank of the first relevant prediction within the top ``k``.

    Args:
        ground_truth: Relevant item ids for one query.
        preds: Ranked predicted ids, best first.
        k: Cutoff rank; ``k <= 0`` means no predictions are considered.

    Returns:
        ``1 / rank`` (1-based) of the first item of ``preds[:k]`` that is in
        ``ground_truth``; 0.0 if there is none or ``ground_truth`` is empty.
    """
    if not ground_truth:
        return 0.0
    for rank, pid in enumerate(_top_k(preds, k), start=1):
        if pid in ground_truth:
            return 1.0 / rank
    return 0.0


def mean_metric(
    queries: Iterable[Set[str]],
    preds_list: Iterable[Sequence[str]],
    fn: _MetricFn,
    k: int,
) -> float:
    """Average a per-query metric over paired ground truths and predictions.

    Args:
        queries: Ground-truth id sets, one per query.
        preds_list: Ranked predictions, one sequence per query, in the same order.
        fn: Per-query metric, e.g. ``recall_at_k`` or ``mrr_at_k``.
        k: Cutoff rank forwarded to ``fn``.

    Returns:
        The arithmetic mean of ``fn`` over all queries, or 0.0 if there are none.

    Raises:
        ValueError: If ``queries`` and ``preds_list`` have different lengths.
            A plain ``zip`` would drop the unmatched tail and report a mean
            over only a subset of the queries.
    """
    missing = object()  # sentinel: distinguishes "exhausted" from any real value
    total = 0.0
    count = 0
    for g, p in itertools.zip_longest(queries, preds_list, fillvalue=missing):
        if g is missing or p is missing:
            raise ValueError("queries and preds_list must have the same length")
        total += fn(g, p, k)
        count += 1
    return total / count if count else 0.0