from typing import List, Dict, Any
import torch
from torchmetrics.functional import (
    retrieval_hit_rate, retrieval_reciprocal_rank, retrieval_recall,
    retrieval_precision, retrieval_average_precision, retrieval_normalized_dcg,
    retrieval_r_precision
)


class Evaluator:

    def __init__(self, candidate_ids: List[int], device: str = 'cpu'):
        """
        Initializes the evaluator with the given candidate IDs.

        Args:
            candidate_ids (List[int]): List of candidate IDs.
            device (str): Device on which evaluation tensors are placed, e.g. 'cpu' or 'cuda'.
        """
        self.candidate_ids = candidate_ids
        self.device = device

    def __call__(self,
                 pred_dict: Dict[int, float],
                 answer_ids: torch.LongTensor,
                 metrics: List[str] = ['mrr', 'hit@3', 'recall@20']) -> Dict[str, float]:
        """
        Evaluates the predictions using the specified metrics.

        Args:
            pred_dict (Dict[int, float]): Dictionary mapping candidate IDs to predicted scores.
            answer_ids (torch.LongTensor): Ground-truth answer IDs.
            metrics (List[str]): List of metrics to be evaluated, including 'mrr', 'rprecision',
                'hit@k', 'recall@k', 'precision@k', 'map@k', and 'ndcg@k'.

        Returns:
            Dict[str, float]: Dictionary of evaluation metrics.
        """
        return self.evaluate(pred_dict, answer_ids, metrics)

    def evaluate(self,
                 pred_dict: Dict[int, float],
                 answer_ids: torch.LongTensor,
                 metrics: List[str] = ['mrr', 'hit@3', 'recall@20']) -> Dict[str, float]:
        """
        Evaluates the predictions using the specified metrics.

        Args:
            pred_dict (Dict[int, float]): Dictionary mapping candidate IDs to predicted scores.
            answer_ids (torch.LongTensor): Ground-truth answer IDs.
            metrics (List[str]): A list of metrics to be evaluated, including 'mrr', 'rprecision',
                'hit@k', 'recall@k', 'precision@k', 'map@k', and 'ndcg@k'.

        Returns:
            Dict[str, float]: Dictionary of evaluation metrics.
        """
        pred_ids = torch.LongTensor(list(pred_dict.keys())).view(-1)
        pred = torch.FloatTensor(list(pred_dict.values())).view(-1)
        answer_ids = answer_ids.view(-1)

        # Scatter the sparse predictions into a dense score vector over all candidates;
        # unscored candidates get a score strictly below the minimum predicted score.
        all_pred = torch.ones(max(self.candidate_ids) + 1, dtype=torch.float) * (pred.min() - 1)
        all_pred[pred_ids] = pred
        all_pred = all_pred[self.candidate_ids]

        # Build a boolean relevance vector over the same candidate ordering.
        bool_gd = torch.zeros(max(self.candidate_ids) + 1, dtype=torch.bool)
        bool_gd[answer_ids] = True
        bool_gd = bool_gd[self.candidate_ids]

        eval_metrics = {}
        for metric in metrics:
            k = int(metric.split('@')[-1]) if '@' in metric else None
            if metric == 'mrr':
                result = retrieval_reciprocal_rank(all_pred, bool_gd)
            elif metric == 'rprecision':
                result = retrieval_r_precision(all_pred, bool_gd)
            elif 'hit' in metric:
                result = retrieval_hit_rate(all_pred, bool_gd, top_k=k)
            elif 'recall' in metric:
                result = retrieval_recall(all_pred, bool_gd, top_k=k)
            elif 'precision' in metric:
                result = retrieval_precision(all_pred, bool_gd, top_k=k)
            elif 'map' in metric:
                result = retrieval_average_precision(all_pred, bool_gd, top_k=k)
            elif 'ndcg' in metric:
                result = retrieval_normalized_dcg(all_pred, bool_gd, top_k=k)
            else:
                raise ValueError(f'Unsupported metric: {metric}')
            eval_metrics[metric] = float(result)

        return eval_metrics

    def evaluate_batch(self,
                       pred_ids: torch.LongTensor,
                       pred: torch.Tensor,
                       answer_ids: List[Any],
                       metrics: List[str] = ['mrr', 'hit@3', 'recall@20']) -> List[Dict[str, float]]:
        """
        Evaluates a batch of queries with the specified metrics.

        Args:
            pred_ids (torch.LongTensor): Candidate IDs that were scored for every query.
            pred (torch.Tensor): Score matrix of shape (len(pred_ids), num_queries).
            answer_ids (List[Any]): One tensor of ground-truth answer IDs per query.
            metrics (List[str]): List of metrics to be evaluated, as in `evaluate`.

        Returns:
            List[Dict[str, float]]: One dictionary of evaluation metrics per query.
        """
        # Dense score matrix over all candidates (one column per query); unscored
        # candidates get a score strictly below the minimum predicted score.
        all_pred = torch.ones((max(self.candidate_ids) + 1, pred.shape[1]), dtype=torch.float) * (pred.min() - 1)
        all_pred[pred_ids, :] = pred
        all_pred = all_pred[self.candidate_ids].t().to(self.device)

        # Boolean relevance matrix: mark each (answer_id, query_index) pair as relevant.
        bool_gd = torch.zeros((max(self.candidate_ids) + 1, pred.shape[1]), dtype=torch.bool)
        bool_gd[torch.concat(answer_ids), torch.repeat_interleave(torch.arange(len(answer_ids)), torch.tensor(list(map(len, answer_ids))))] = True
        bool_gd = bool_gd[self.candidate_ids].t().to(self.device)

        results = []
        for i in range(len(answer_ids)):
            eval_metrics = {}
            for metric in metrics:
                k = int(metric.split('@')[-1]) if '@' in metric else None
                if metric == 'mrr':
                    result = retrieval_reciprocal_rank(all_pred[i], bool_gd[i])
                elif metric == 'rprecision':
                    result = retrieval_r_precision(all_pred[i], bool_gd[i])
                elif 'hit' in metric:
                    result = retrieval_hit_rate(all_pred[i], bool_gd[i], top_k=k)
                elif 'recall' in metric:
                    result = retrieval_recall(all_pred[i], bool_gd[i], top_k=k)
                elif 'precision' in metric:
                    result = retrieval_precision(all_pred[i], bool_gd[i], top_k=k)
                elif 'map' in metric:
                    result = retrieval_average_precision(all_pred[i], bool_gd[i], top_k=k)
                elif 'ndcg' in metric:
                    result = retrieval_normalized_dcg(all_pred[i], bool_gd[i], top_k=k)
                else:
                    raise ValueError(f'Unsupported metric: {metric}')
                eval_metrics[metric] = float(result)
            results.append(eval_metrics)
        return results
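

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only): a toy candidate pool of five items
    # with made-up scores and ground-truth IDs, just to show the call pattern.
    evaluator = Evaluator(candidate_ids=[0, 1, 2, 3, 4])

    # Single-query evaluation: scores keyed by candidate ID, answers as a LongTensor.
    pred_dict = {0: 0.9, 1: 0.1, 2: 0.7, 3: 0.3, 4: 0.5}
    answer_ids = torch.LongTensor([0, 2])
    print(evaluator(pred_dict, answer_ids, metrics=['mrr', 'hit@3', 'recall@3']))

    # Batched evaluation: `pred` holds one column of scores per query and
    # `answer_ids` is a list with one LongTensor of answers per query.
    pred_ids = torch.LongTensor([0, 1, 2, 3, 4])
    pred = torch.rand(len(pred_ids), 2)  # random scores for 2 queries
    batch_answers = [torch.LongTensor([0]), torch.LongTensor([2, 4])]
    print(evaluator.evaluate_batch(pred_ids, pred, batch_answers, metrics=['mrr', 'hit@3']))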