# squad2-qa/src/etl/squad_v2_loader.py
# Kimis Perros
# Initial deployment
# 461f64f
"""
Contains core ETL functionality to load train/dev datasets.
"""
from typing import Dict
from pathlib import Path
import json
import pandas as pd
from src.etl.types import QAExample
from src.utils.constants import Col, RawField
# Text encoding passed to `open` when reading the SQuAD JSON file.
DEFAULT_ENCODING = "utf-8"
def load_squad_v2_df(file_path: Path) -> pd.DataFrame:
    """
    Loads input SQuAD v2.0 JSON file as a Pandas DF (one row per question).

    Returns columns:
        - Col.QUESTION_ID.value   : str (unique)
        - Col.TITLE.value         : str
        - Col.CONTEXT.value       : str
        - Col.QUESTION.value      : str
        - Col.IS_IMPOSSIBLE.value : bool
        - Col.ANSWER_TEXTS.value  : List[str] (all gold answers; [] if impossible)
        - Col.ANSWER_STARTS.value : List[int] (all start offsets; [] if impossible)
        - Col.NUM_ANSWERS.value   : int (len(answers))

    An input that contains no questions yields an empty DF that still carries
    all of the columns listed above.

    Raises:
        AssertionError: if the file does not exist, the top-level keys or the
            version string are not as expected, a question's answers field is
            not a list, or a question ID occurs more than once.
    """
    assert file_path.exists(), f"File not found: {file_path}"
    with file_path.open("r", encoding=DEFAULT_ENCODING) as f:
        raw = json.load(f)
    assert (
        set(raw.keys()) == {RawField.VERSION.value, RawField.DATA.value}
        and raw[RawField.VERSION.value] == "v2.0"
    ), "Unexpected input data formatting."

    # Explicit column order (identical to the per-row dict key order below).
    # Passing it to the DataFrame constructor guarantees the schema exists
    # even when no rows were collected; otherwise the uniqueness check would
    # fail with a KeyError on an empty input.
    columns = [
        Col.QUESTION_ID.value,
        Col.TITLE.value,
        Col.CONTEXT.value,
        Col.QUESTION.value,
        Col.IS_IMPOSSIBLE.value,
        Col.ANSWER_TEXTS.value,
        Col.ANSWER_STARTS.value,
        Col.NUM_ANSWERS.value,
    ]

    rows = []
    for article in raw[RawField.DATA.value]:
        title = article[Col.TITLE.value]
        for paragraph in article[RawField.PARAGRAPHS.value]:
            context = paragraph[Col.CONTEXT.value]
            for qa in paragraph[RawField.QAS.value]:
                # gold answers (may be empty if unanswerable)
                answers = qa[RawField.ANSWERS.value]
                assert isinstance(answers, list), "Unexpected raw answers type."
                gold_texts = [a[RawField.ANSWER_TEXT.value] for a in answers]
                gold_starts = [a[RawField.ANSWER_START.value] for a in answers]
                # Structural check: lengths must match
                assert len(gold_texts) == len(
                    gold_starts
                ), f"Mismatched gold lengths for {qa[Col.QUESTION_ID.value]}"
                rows.append(
                    {
                        Col.QUESTION_ID.value: qa[Col.QUESTION_ID.value],
                        Col.TITLE.value: title,
                        Col.CONTEXT.value: context,
                        Col.QUESTION.value: qa[Col.QUESTION.value],
                        Col.IS_IMPOSSIBLE.value: bool(qa[Col.IS_IMPOSSIBLE.value]),
                        Col.ANSWER_TEXTS.value: gold_texts,
                        Col.ANSWER_STARTS.value: gold_starts,
                        Col.NUM_ANSWERS.value: len(gold_texts),
                    }
                )

    df = pd.DataFrame(rows, columns=columns)
    assert (
        df[Col.QUESTION_ID.value].duplicated().sum() == 0
    ), "Unexpected non-unique question ID."
    return df
def df_to_examples_map(df: pd.DataFrame) -> Dict[str, QAExample]:
    """
    Build a question-ID -> QAExample lookup from a DF, one entry per row.

    All required columns must be present and no question ID may repeat;
    either violation trips an assertion.
    """
    # Resolve the column names once so the loop body stays compact.
    id_col = Col.QUESTION_ID.value
    title_col = Col.TITLE.value
    context_col = Col.CONTEXT.value
    question_col = Col.QUESTION.value
    impossible_col = Col.IS_IMPOSSIBLE.value
    texts_col = Col.ANSWER_TEXTS.value
    starts_col = Col.ANSWER_STARTS.value

    required = {
        id_col,
        title_col,
        context_col,
        question_col,
        impossible_col,
        texts_col,
        starts_col,
    }
    missing = required.difference(df.columns)
    assert not missing, f"Missing required columns: {sorted(missing)}"

    examples: Dict[str, QAExample] = {}
    for _, record in df.iterrows():
        question_id = record[id_col]
        assert (
            question_id not in examples
        ), f"Duplicate id during build: {question_id}"
        examples[question_id] = QAExample(
            question_id=question_id,
            title=record[title_col],
            question=record[question_col],
            context=record[context_col],
            # list(...) makes a fresh list per example so the result does not
            # share list objects with the DF; a falsy cell becomes [].
            answer_texts=list(record[texts_col] or []),
            answer_starts=list(record[starts_col] or []),
            is_impossible=record[impossible_col],
        )
    return examples