|
|
""" |
|
|
Contains core ETL functionality to load train/dev datasets. |
|
|
""" |
|
|
|
|
|
from typing import Dict |
|
|
from pathlib import Path |
|
|
import json |
|
|
import pandas as pd |
|
|
from src.etl.types import QAExample |
|
|
from src.utils.constants import Col, RawField |
|
|
|
|
|
DEFAULT_ENCODING = "utf-8" |
|
|
|
|
|
|
|
|
def load_squad_v2_df(file_path: Path) -> pd.DataFrame:
    """
    Loads input SQuAD v2.0 JSON file as a Pandas DF, one row per question.

    Args:
        file_path: Path to the raw SQuAD v2.0 JSON file.

    Returns columns:
    - Col.QUESTION_ID.value : str (unique)
    - Col.TITLE.value : str
    - Col.CONTEXT.value : str
    - Col.QUESTION.value : str
    - Col.IS_IMPOSSIBLE.value : bool
    - Col.ANSWER_TEXTS.value : List[str] (all gold answers; [] if impossible)
    - Col.ANSWER_STARTS.value : List[int] (all start offsets; [] if impossible)
    - Col.NUM_ANSWERS.value : int (len(answers))

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the JSON does not match the expected SQuAD v2.0
            schema, or question IDs are not unique.
    """
    # Validation uses explicit raises (not `assert`) so the checks survive
    # running under `python -O`, which strips assert statements.
    if not file_path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")
    with file_path.open("r", encoding=DEFAULT_ENCODING) as f:
        raw = json.load(f)

    # Top level must be exactly {version, data} with version pinned to v2.0.
    if (
        set(raw.keys()) != {RawField.VERSION.value, RawField.DATA.value}
        or raw[RawField.VERSION.value] != "v2.0"
    ):
        raise ValueError("Unexpected input data formatting.")

    rows = []
    for article in raw[RawField.DATA.value]:
        title = article[Col.TITLE.value]

        for paragraph in article[RawField.PARAGRAPHS.value]:
            context = paragraph[Col.CONTEXT.value]
            for qa in paragraph[RawField.QAS.value]:
                answers = qa[RawField.ANSWERS.value]
                if not isinstance(answers, list):
                    raise ValueError("Unexpected raw answers type.")
                gold_texts = [a[RawField.ANSWER_TEXT.value] for a in answers]
                gold_starts = [a[RawField.ANSWER_START.value] for a in answers]

                # Texts and starts are extracted from the same list, so a
                # mismatch here indicates a malformed answer entry.
                if len(gold_texts) != len(gold_starts):
                    raise ValueError(
                        f"Mismatched gold lengths for {qa[Col.QUESTION_ID.value]}"
                    )

                rows.append(
                    {
                        Col.QUESTION_ID.value: qa[Col.QUESTION_ID.value],
                        Col.TITLE.value: title,
                        Col.CONTEXT.value: context,
                        Col.QUESTION.value: qa[Col.QUESTION.value],
                        Col.IS_IMPOSSIBLE.value: bool(qa[Col.IS_IMPOSSIBLE.value]),
                        Col.ANSWER_TEXTS.value: gold_texts,
                        Col.ANSWER_STARTS.value: gold_starts,
                        Col.NUM_ANSWERS.value: len(gold_texts),
                    }
                )
    df = pd.DataFrame(rows)
    # IDs are the join key for downstream maps, so enforce uniqueness here.
    if df[Col.QUESTION_ID.value].duplicated().any():
        raise ValueError("Unexpected non-unique question ID.")
    return df
|
|
|
|
|
|
|
|
def df_to_examples_map(df: pd.DataFrame) -> Dict[str, QAExample]:
    """
    Convert DF -> Dict[question ID, QAExample].

    Loader already asserted uniqueness of IDs and basic structure; the
    duplicate check below is a cheap defensive re-check in case the DF
    was modified between loading and conversion.

    Args:
        df: DataFrame with at least the required loader columns
            (NUM_ANSWERS is not needed here).

    Returns:
        Mapping from question ID to its QAExample.

    Raises:
        ValueError: If required columns are missing or a duplicate
            question ID is encountered.
    """
    required = {
        Col.QUESTION_ID.value,
        Col.TITLE.value,
        Col.CONTEXT.value,
        Col.QUESTION.value,
        Col.IS_IMPOSSIBLE.value,
        Col.ANSWER_TEXTS.value,
        Col.ANSWER_STARTS.value,
    }
    missing = required - set(df.columns)
    # Explicit raise (not `assert`) so validation survives `python -O`.
    if missing:
        raise ValueError(f"Missing required columns: {sorted(missing)}")

    ex_map: Dict[str, QAExample] = {}
    for _, row in df.iterrows():
        qid = row[Col.QUESTION_ID.value]
        if qid in ex_map:
            raise ValueError(f"Duplicate id during build: {qid}")
        ex_map[qid] = QAExample(
            question_id=qid,
            title=row[Col.TITLE.value],
            question=row[Col.QUESTION.value],
            context=row[Col.CONTEXT.value],
            # `or []` guards against falsy cells (e.g. None) so the lists
            # are always real lists, decoupled from the DF's storage.
            answer_texts=list(row[Col.ANSWER_TEXTS.value] or []),
            answer_starts=list(row[Col.ANSWER_STARTS.value] or []),
            is_impossible=row[Col.IS_IMPOSSIBLE.value],
        )
    return ex_map
|
|
|