File size: 3,557 Bytes
461f64f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""
Frozen dataclasses representing individual ground-truth (gold) examples and individual model predictions.

Benefits:
    - Instance immutability: fields cannot be reassigned after creation, preventing accidental changes
      (note: this is shallow — list fields can still be mutated in place)
    - Explicit type annotations on every field, removing ambiguity
    - Compact implementation: less boilerplate code (e.g., __init__() is auto-generated)
    - __post_init__() applies the same validation to every object created
"""

from __future__ import annotations
from dataclasses import dataclass
from typing import List, Dict


@dataclass(frozen=True)
class QAExample:
    """
    One gold (ground-truth) SQuAD question, stored as a frozen dataclass.

    Every acceptable gold answer is kept, mirroring the official evaluation
    script. Unanswerable questions (is_impossible=True) must carry empty
    answer_texts/answer_starts, while answerable ones must carry at least one
    answer; __post_init__() enforces both rules at construction time.
    Freezing is shallow: the list fields themselves remain mutable.
    """

    question_id: str
    title: str
    question: str
    context: str
    answer_texts: List[str]  # gold answer strings; empty when is_impossible is True
    answer_starts: List[int]  # index-aligned with answer_texts; empty when is_impossible is True
    is_impossible: bool

    def __post_init__(self) -> None:
        """
        Validate the example, stopping at the first violated rule.

        Raises:
            ValueError: if is_impossible is not a bool, if answer_texts and
                answer_starts differ in length, or if their emptiness
                contradicts is_impossible.
        """
        if not isinstance(self.is_impossible, bool):
            raise ValueError("is_impossible field needs to be of boolean type.")

        texts, starts = self.answer_texts, self.answer_starts
        if len(texts) != len(starts):
            raise ValueError(
                "Incompatible sizes of answer_texts/answer_starts of QAExample."
            )

        # An unanswerable question must not carry any gold answers...
        if self.is_impossible and (texts or starts):
            raise ValueError(
                "Incompatible configuration between is_impossible (True) Vs answer_texts/answer_starts (non-empty) of QAExample."
            )

        # ...whereas an answerable one must carry at least one.
        if not self.is_impossible and (not texts or not starts):
            raise ValueError(
                "Incompatible configuration between is_impossible (False) Vs answer_texts/answer_starts (empty) of QAExample."
            )


@dataclass(frozen=True)
class Prediction:
    """
    Single model prediction for one question.

    __post_init__() validates that confidence is a probability in [0, 1].
    """

    question_id: str
    predicted_answer: str  # '' if the model predicts no-answer
    confidence: float  # confidence level that the question is answerable via the context
    is_impossible: bool

    def __post_init__(self) -> None:
        """
        Validate the prediction.

        Raises:
            ValueError: if confidence is outside [0, 1]. NaN is rejected as
                well, since every comparison involving NaN is False.
        """
        if not (0 <= self.confidence <= 1):
            raise ValueError(
                "Confidence of Prediction object should be a probability score [0, 1]."
            )

    @classmethod
    def null(cls, question_id: str, confidence: float = 0.0) -> Prediction:
        """
        Build the standard no-answer Prediction: empty predicted_answer and
        is_impossible=True.

        Using this constructor everywhere keeps no-answer predictions uniform
        throughout the code. confidence is validated like any other Prediction.
        """
        return cls(
            question_id=question_id,
            predicted_answer="",
            confidence=confidence,
            is_impossible=True,
        )

    @classmethod
    def flatten_predicted_answers(
        cls, predictions: Dict[str, Prediction], *, check_ids: bool = True
    ) -> Dict[str, str]:
        """
        Convert Dict[qid, Prediction] -> Dict[qid, str], similar to the
        official evaluation script's prediction style.

        Args:
            predictions: mapping from question ID to that question's Prediction.
            check_ids: when True (default), verify that every key equals the
                question_id stored inside its Prediction, so a mis-keyed entry
                cannot be silently scored against the wrong question.

        Returns:
            A new dict mapping each key to its predicted_answer ('' for no-answer).

        Raises:
            ValueError: if check_ids is True and a key does not match the
                question_id of its Prediction.
        """
        flattened: Dict[str, str] = {}
        for qid, prediction in predictions.items():
            if check_ids and qid != prediction.question_id:
                raise ValueError(
                    f"Prediction dict key {qid!r} does not match the Prediction's "
                    f"question_id {prediction.question_id!r}."
                )
            flattened[qid] = prediction.predicted_answer
        return flattened