""" Specifies the Evaluator's functionality. Leverages metrics as computed in the official SQuAD v2.0 evaluation script to ensure reporting consistency. """ from typing import Dict, List from src.evaluation.metrics import Metrics from src.etl.types import QAExample, Prediction from src.evaluation.squad_v2_official import ( normalize_answer, compute_exact, compute_f1, ) class Evaluator: def evaluate( self, predictions: Dict[str, Prediction], examples: Dict[str, QAExample] ) -> Metrics: assert len(examples) > 0, "Examples must be non-empty." assert isinstance(predictions, dict) and isinstance( examples, dict ), "Inputs must be dicts." extras = set(predictions.keys()).symmetric_difference(set(examples.keys())) assert ( not extras ), f"Differences across predictions/examples question ids: {list(sorted(extras))[:3]} ..." golds: Dict[str, List[str]] = {} for qid, ex in examples.items(): if ex.is_impossible: golds[qid] = [""] else: # similar to the official script - filter out golds which normalize to empty filtered = [t for t in ex.answer_texts if normalize_answer(str(t))] golds[qid] = filtered if filtered else [""] em_sum = 0.0 f1_sum = 0.0 for qid, gold_list in golds.items(): pred_obj = predictions.get(qid) if not pred_obj: raise ValueError( "Unexpected absence of Prediction object for question ID:%s" % qid ) pred_text = pred_obj.predicted_answer assert isinstance(pred_text, str), "Unexpected predicted answer type." best_em = max((compute_exact(g, pred_text) for g in gold_list), default=0) best_f1 = max((compute_f1(g, pred_text) for g in gold_list), default=0.0) em_sum += float(best_em) f1_sum += float(best_f1) total = len(golds) assert total >= 1, "Unexpected empty dict of ground-truth items." return Metrics( exact_score=100.0 * (em_sum / total), f1_score=100.0 * (f1_sum / total), total_num_instances=total, )