|
|
""" |
|
|
Specifies the Evaluator's functionality. |
|
|
Leverages metrics as computed in the official SQuAD v2.0 evaluation |
|
|
script to ensure reporting consistency. |
|
|
""" |
|
|
|
|
|
from typing import Dict, List |
|
|
from src.evaluation.metrics import Metrics |
|
|
from src.etl.types import QAExample, Prediction |
|
|
from src.evaluation.squad_v2_official import ( |
|
|
normalize_answer, |
|
|
compute_exact, |
|
|
compute_f1, |
|
|
) |
|
|
|
|
|
|
|
|
class Evaluator:
    """Scores model predictions with the official SQuAD v2.0 metrics.

    Per-example scoring is delegated to the official evaluation-script
    functions (``normalize_answer``, ``compute_exact``, ``compute_f1``)
    so reported numbers stay consistent with the reference implementation.
    """

    def evaluate(
        self, predictions: Dict[str, Prediction], examples: Dict[str, QAExample]
    ) -> Metrics:
        """Compute exact-match and F1 scores for ``predictions`` against ``examples``.

        Args:
            predictions: Mapping from question id to a ``Prediction`` whose
                ``predicted_answer`` is the model's answer string.
            examples: Mapping from question id to the ground-truth
                ``QAExample``. Must be non-empty and share exactly the same
                key set as ``predictions``.

        Returns:
            A ``Metrics`` with ``exact_score`` and ``f1_score`` scaled to
            [0, 100] and ``total_num_instances`` set to the number of
            scored examples.

        Raises:
            TypeError: If either input is not a dict.
            ValueError: If ``examples`` is empty, the key sets differ, a
                prediction is missing, or a predicted answer is not a string.
        """
        # Validate with real exceptions rather than `assert`, which is
        # stripped under `python -O` and would silently skip these checks.
        if not isinstance(predictions, dict) or not isinstance(examples, dict):
            raise TypeError("Inputs must be dicts.")
        if not examples:
            raise ValueError("Examples must be non-empty.")
        mismatched = set(predictions).symmetric_difference(examples)
        if mismatched:
            raise ValueError(
                "Differences across predictions/examples question ids: "
                f"{sorted(mismatched)[:3]} ..."
            )

        # Collect the acceptable gold answers per question id. Unanswerable
        # questions score against the empty string, matching the official
        # script; answer texts that normalize to empty are dropped, falling
        # back to [""] when nothing survives.
        golds: Dict[str, List[str]] = {}
        for qid, ex in examples.items():
            if ex.is_impossible:
                golds[qid] = [""]
            else:
                filtered = [t for t in ex.answer_texts if normalize_answer(str(t))]
                golds[qid] = filtered if filtered else [""]

        em_sum = 0.0
        f1_sum = 0.0
        for qid, gold_list in golds.items():
            pred_obj = predictions.get(qid)
            # Key sets were verified equal above, so this is purely defensive.
            # `is None` (not truthiness) avoids rejecting a present-but-falsy
            # Prediction object.
            if pred_obj is None:
                raise ValueError(
                    "Unexpected absence of Prediction object for question ID:%s" % qid
                )
            pred_text = pred_obj.predicted_answer
            if not isinstance(pred_text, str):
                raise ValueError("Unexpected predicted answer type.")

            # Score against the best-matching gold answer, as the official
            # evaluation script does.
            best_em = max((compute_exact(g, pred_text) for g in gold_list), default=0)
            best_f1 = max((compute_f1(g, pred_text) for g in gold_list), default=0.0)

            em_sum += float(best_em)
            f1_sum += float(best_f1)

        # `golds` mirrors `examples`, which was checked non-empty, so the
        # division below cannot hit zero.
        total = len(golds)
        return Metrics(
            exact_score=100.0 * (em_sum / total),
            f1_score=100.0 * (f1_sum / total),
            total_num_instances=total,
        )
|
|
|