File size: 3,990 Bytes
461f64f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""
Contains core ETL functionality to load train/dev datasets.
"""

from typing import Dict
from pathlib import Path
import json
import pandas as pd
from src.etl.types import QAExample
from src.utils.constants import Col, RawField

DEFAULT_ENCODING = "utf-8"


def load_squad_v2_df(file_path: Path) -> pd.DataFrame:
    """
    Load a SQuAD v2.0 JSON file as a pandas DataFrame.

    Parameters
    ----------
    file_path : Path
        Path to the raw SQuAD v2.0 JSON file.

    Returns
    -------
    pd.DataFrame with columns:
        - Col.QUESTION_ID.value   : str (unique)
        - Col.TITLE.value         : str
        - Col.CONTEXT.value       : str
        - Col.QUESTION.value      : str
        - Col.IS_IMPOSSIBLE.value : bool
        - Col.ANSWER_TEXTS.value  : List[str]   (all gold answers; [] if impossible)
        - Col.ANSWER_STARTS.value : List[int]   (all start offsets; [] if impossible)
        - Col.NUM_ANSWERS.value   : int         (len(answers))

    Raises
    ------
    FileNotFoundError
        If ``file_path`` does not exist.
    ValueError
        If the file does not match the expected SQuAD v2.0 structure.
    """
    # Validate with real exceptions rather than `assert`: asserts are
    # stripped under `python -O`, which would silently disable every
    # input check in this loader.
    if not file_path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")
    with file_path.open("r", encoding=DEFAULT_ENCODING) as f:
        raw = json.load(f)

    if (
        set(raw.keys()) != {RawField.VERSION.value, RawField.DATA.value}
        or raw[RawField.VERSION.value] != "v2.0"
    ):
        raise ValueError("Unexpected input data formatting.")

    columns = [
        Col.QUESTION_ID.value,
        Col.TITLE.value,
        Col.CONTEXT.value,
        Col.QUESTION.value,
        Col.IS_IMPOSSIBLE.value,
        Col.ANSWER_TEXTS.value,
        Col.ANSWER_STARTS.value,
        Col.NUM_ANSWERS.value,
    ]

    rows = []
    for article in raw[RawField.DATA.value]:
        title = article[Col.TITLE.value]

        for paragraph in article[RawField.PARAGRAPHS.value]:
            context = paragraph[Col.CONTEXT.value]
            for qa in paragraph[RawField.QAS.value]:

                # Gold answers (may be empty if unanswerable).
                answers = qa[RawField.ANSWERS.value]
                if not isinstance(answers, list):
                    raise ValueError("Unexpected raw answers type.")
                gold_texts = [a[RawField.ANSWER_TEXT.value] for a in answers]
                gold_starts = [a[RawField.ANSWER_START.value] for a in answers]

                # Structural check: one start offset per gold answer text.
                if len(gold_texts) != len(gold_starts):
                    raise ValueError(
                        f"Mismatched gold lengths for {qa[Col.QUESTION_ID.value]}"
                    )

                rows.append(
                    {
                        Col.QUESTION_ID.value: qa[Col.QUESTION_ID.value],
                        Col.TITLE.value: title,
                        Col.CONTEXT.value: context,
                        Col.QUESTION.value: qa[Col.QUESTION.value],
                        Col.IS_IMPOSSIBLE.value: bool(qa[Col.IS_IMPOSSIBLE.value]),
                        Col.ANSWER_TEXTS.value: gold_texts,
                        Col.ANSWER_STARTS.value: gold_starts,
                        Col.NUM_ANSWERS.value: len(gold_texts),
                    }
                )

    # Explicit `columns=` keeps the schema stable even when `rows` is empty;
    # otherwise the uniqueness check below would KeyError on an empty file.
    df = pd.DataFrame(rows, columns=columns)
    if df[Col.QUESTION_ID.value].duplicated().any():
        raise ValueError("Unexpected non-unique question ID.")
    return df


def df_to_examples_map(df: pd.DataFrame) -> Dict[str, QAExample]:
    """
    Build a mapping from question ID to QAExample out of a loader DataFrame.

    The loader already asserted ID uniqueness and basic structure; the
    checks here are cheap re-validation against programming errors.
    """
    required = {
        Col.QUESTION_ID.value,
        Col.TITLE.value,
        Col.CONTEXT.value,
        Col.QUESTION.value,
        Col.IS_IMPOSSIBLE.value,
        Col.ANSWER_TEXTS.value,
        Col.ANSWER_STARTS.value,
    }
    missing = required - set(df.columns)
    assert not missing, f"Missing required columns: {sorted(missing)}"

    examples: Dict[str, QAExample] = {}
    for record in df.to_dict("records"):
        qid = record[Col.QUESTION_ID.value]
        assert qid not in examples, f"Duplicate id during build: {qid}"
        examples[qid] = QAExample(
            question_id=qid,
            title=record[Col.TITLE.value],
            question=record[Col.QUESTION.value],
            context=record[Col.CONTEXT.value],
            # Copy the answer lists so no two examples share a list object.
            answer_texts=list(record[Col.ANSWER_TEXTS.value] or []),
            answer_starts=list(record[Col.ANSWER_STARTS.value] or []),
            is_impossible=record[Col.IS_IMPOSSIBLE.value],
        )
    return examples