|
|
""" |
|
|
Defines frozen dataclasses representing individual ground-truth examples and individual predictions.
|
|
|
|
|
Benefits: |
|
|
- Instance immutability: fields cannot be reassigned after construction, which prevents accidental, unexpected changes to the data
|
|
- Explicit type annotations on every field remove ambiguity
|
|
- Compact implementation: reduces boilerplate code (e.g., __init__() is auto-generated) |
|
|
- __post_init__() applies the same validation to every object created
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
from dataclasses import dataclass |
|
|
from typing import List, Dict |
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class QAExample:
    """
    Single QA instance pulled from SQuAD (gold/ground-truth instance), stored as a
    frozen dataclass so its fields cannot be reassigned during the code's execution.

    As in the official evaluation script, all possible gold answers are stored.

    Fields:
        question_id: identifier of the question.
        title: title field of the example.
        question: question text.
        context: passage text.
        answer_texts: all gold answer strings.
        answer_starts: one start offset per entry of answer_texts (same length);
            presumably character offsets into context, per the SQuAD convention.
        is_impossible: True when the question has no answer.

    __post_init__() guarantees that:
        - is_impossible is a bool;
        - answer_texts and answer_starts have the same length;
        - both are empty when is_impossible is True and non-empty otherwise.

    answer_texts / answer_starts are copied into fresh lists at construction, so
    later mutation of the sequences passed in by the caller cannot invalidate
    the checks above. Because these fields are lists, instances are not hashable.

    Raises:
        ValueError: if any of the guarantees above is violated.
    """

    question_id: str
    title: str
    question: str
    context: str
    answer_texts: List[str]
    answer_starts: List[int]
    is_impossible: bool

    def __post_init__(self):
        if not isinstance(self.is_impossible, bool):
            raise ValueError("is_impossible field needs to be of boolean type.")

        # Defensive copies: "frozen" only blocks field reassignment, while list
        # fields would otherwise stay aliased to the caller's (mutable) objects.
        # object.__setattr__ is the supported way to set fields of a frozen
        # dataclass from __post_init__.
        object.__setattr__(self, "answer_texts", list(self.answer_texts))
        object.__setattr__(self, "answer_starts", list(self.answer_starts))

        if len(self.answer_texts) != len(self.answer_starts):
            raise ValueError(
                "Incompatible sizes of answer_texts/answer_starts of QAExample."
            )

        if self.is_impossible:
            if self.answer_texts or self.answer_starts:
                raise ValueError(
                    "Incompatible configuration between is_impossible (True) Vs answer_texts/answer_starts (non-empty) of QAExample."
                )
        else:
            if not self.answer_texts or not self.answer_starts:
                raise ValueError(
                    "Incompatible configuration between is_impossible (False) Vs answer_texts/answer_starts (empty) of QAExample."
                )
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class Prediction:
    """
    Single model prediction for a question.

    Fields:
        question_id: identifier of the question this prediction answers.
        predicted_answer: predicted answer text; "" for a no-answer prediction.
        confidence: score in [0, 1].
        is_impossible: True when the model predicts that there is no answer.

    __post_init__() validates that:
        - confidence lies in [0, 1] (NaN is rejected as well);
        - is_impossible is a bool (same rule as for gold examples);
        - a no-answer prediction (is_impossible=True) has an empty
          predicted_answer, i.e. the form produced by Prediction.null().

    Raises:
        ValueError: if any of the checks above fails.
    """

    question_id: str
    predicted_answer: str
    confidence: float
    is_impossible: bool

    def __post_init__(self):
        # A chained comparison with NaN is False, so NaN is rejected here too.
        if not (0 <= self.confidence <= 1):
            raise ValueError(
                "Confidence of Prediction object should be a probability score [0, 1]."
            )

        if not isinstance(self.is_impossible, bool):
            raise ValueError("is_impossible field needs to be of boolean type.")

        # The official SQuAD v2.0 evaluation treats an empty string as "no answer".
        # A prediction flagged impossible but carrying text would be scored as an
        # answer once flattened, so reject that inconsistent combination.
        if self.is_impossible and self.predicted_answer:
            raise ValueError(
                "Incompatible configuration between is_impossible (True) Vs predicted_answer (non-empty) of Prediction."
            )

    @classmethod
    def null(cls, question_id: str, confidence: float = 0.0) -> Prediction:
        """
        Build the standard no-answer Prediction.

        Args:
            question_id: identifier of the question.
            confidence: score in [0, 1]; defaults to 0.0.

        Returns:
            A Prediction with predicted_answer="" and is_impossible=True.
        """
        return cls(
            question_id=question_id,
            predicted_answer="",
            confidence=confidence,
            is_impossible=True,
        )

    @classmethod
    def flatten_predicted_answers(
        cls, predictions: Dict[str, Prediction]
    ) -> Dict[str, str]:
        """
        Convert Dict[qid, Prediction] -> Dict[qid, str], the shape used by the
        official evaluation script.

        Args:
            predictions: mapping from question id to Prediction.

        Returns:
            Mapping from question id to predicted answer text ("" for no-answer).
        """
        return {qid: p.predicted_answer for qid, p in predictions.items()}
|
|
|