Kimis Perros committed on
Commit
461f64f
·
0 Parent(s):

Initial deployment

Browse files
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ checkpoint/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # --- Python Cache ---
2
+ __pycache__/
3
+ .pytest_cache/
4
+ *.py[cod]
5
+
6
+ # --- OS-specific files ---
7
+ .DS_Store
README.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SQuAD 2.0 QA System
3
+ colorFrom: blue
4
+ colorTo: green
5
+ sdk: gradio
6
+ sdk_version: 4.0.0
7
+ app_file: app.py
8
+ ---
9
+
10
+ # SQuAD 2.0 Question Answering System
11
+
12
+ ## Model Details
13
+ - **General-Purpose Pre-Trained Model**: bert-base-uncased
14
+ - **Training Dataset**: SQuAD 2.0 (~130K examples)
15
+ - **Performance**: >70% F1 score on dev set
16
+ - **Capabilities**: Handles both answerable and unanswerable questions
17
+
18
+ ## Usage
19
+ Provide a context paragraph and ask a question to extract the answer.
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Question Answering System trained on SQuAD 2.0
"""

import sys
from pathlib import Path

import gradio as gr

# Make the repository root importable so the 'src' package resolves.
_APP_DIR = Path(__file__).parent
sys.path.insert(0, str(_APP_DIR))

from src.models.bert_based_model import BertBasedQAModel
from src.config.model_configs import OriginalBertQAConfig
from src.etl.types import QAExample

# Load the fine-tuned checkpoint once at startup; inference runs on CPU.
model = BertBasedQAModel.load_from_experiment(
    experiment_dir=Path("checkpoint"), config_class=OriginalBertQAConfig, device="cpu"
)


def answer_question(context: str, question: str) -> str:
    """Run the QA model on (context, question) and return the answer text."""
    context = context.strip()
    question = question.strip()
    if not context:
        return "Please provide context text."
    if not question:
        return "Please provide a question."

    try:
        example = QAExample(
            question_id="demo",
            title="Demo",
            question=question,
            context=context,
            answer_texts=[],
            answer_starts=[],
            is_impossible=False,
        )
        prediction = model.predict({"demo": example})["demo"]
        # Empty string means the model predicted "no answer".
        return prediction.predicted_answer or "No answer found."
    except Exception as e:  # surface failures in the UI instead of crashing
        return f"Error: {str(e)}"


demo = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Textbox(lines=8, placeholder="Enter context paragraph...", label="Context"),
        gr.Textbox(placeholder="Enter your question...", label="Question"),
    ],
    outputs=gr.Textbox(label="Answer", show_copy_button=True),
    title="SQuAD 2.0 Question Answering",
    description="BERT-base model fine-tuned on SQuAD 2.0 dataset",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()
checkpoint/config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backbone_name": "bert-base-uncased",
3
+ "max_sequence_length": 384,
4
+ "learning_rate": 5e-05,
5
+ "num_epochs": 2,
6
+ "batch_size": 48,
7
+ "eval_batch_size": 1024,
8
+ "no_answer_threshold": 0.0,
9
+ "device": "cuda"
10
+ }
checkpoint/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64535ce08c77f38ce4243a75daa6ac4696de0999319fb1fb6d8c6550ed18ba2a
3
+ size 438019655
checkpoint/tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
checkpoint/tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": false,
45
+ "cls_token": "[CLS]",
46
+ "do_lower_case": true,
47
+ "extra_special_tokens": {},
48
+ "mask_token": "[MASK]",
49
+ "model_max_length": 512,
50
+ "pad_token": "[PAD]",
51
+ "sep_token": "[SEP]",
52
+ "strip_accents": null,
53
+ "tokenize_chinese_chars": true,
54
+ "tokenizer_class": "BertTokenizer",
55
+ "unk_token": "[UNK]"
56
+ }
checkpoint/tokenizer/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch==2.8.0
2
+ transformers==4.57.0
3
+ gradio==4.0.0
4
+ pandas==2.3.3
5
+ numpy==2.2.6
src/__init__.py ADDED
File without changes
src/config/__init__.py ADDED
File without changes
src/config/model_configs.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Immutable configurations enabling to share common fields across the specific models used.
"""

from abc import ABC
from dataclasses import dataclass
from typing import ClassVar


@dataclass(frozen=True)
class BaseModelConfig(ABC):
    """
    Container storing configurations useful across all QA models.
    """


@dataclass(frozen=True)
class AlwaysNoAnswerModelConfig(BaseModelConfig):
    """
    Trivial baseline that always predicts no-answer ("").
    """

    MODEL_TYPE: ClassVar[str] = "always_no_answer"


@dataclass(frozen=True)
class SentenceEmbeddingModelConfig(BaseModelConfig):
    """
    Config object for the simpler baseline model.
    """

    # ClassVar keeps MODEL_TYPE out of the generated dataclass machinery
    # (__init__/__eq__/...): it is shared by the class, not a per-instance field.
    MODEL_TYPE: ClassVar[str] = "embedding_best_sentence"
    # TODO - consider switching to other defaults for non-Apple users
    device: str = "mps"
    sentence_model_name: str = "all-MiniLM-L6-v2"
    no_answer_threshold: float = 0.5


@dataclass(frozen=True)
class BertQAConfig(BaseModelConfig, ABC):
    """
    Shared super-class config to be sub-classed by BERT model variants.
    """

    # Declared here (without defaults) so type checkers see them on the base;
    # concrete sub-classes supply the actual default values.
    backbone_name: str
    max_sequence_length: int
    learning_rate: float
    num_epochs: int
    batch_size: int
    eval_batch_size: int
    no_answer_threshold: float
    device: str = "cuda"


@dataclass(frozen=True)
class TinyBertQAConfig(BertQAConfig):
    """
    Config for a Tiny BERT-based extractive QA system.
    """

    MODEL_TYPE: ClassVar[str] = "tinybert_qa"
    # General-purpose checkpoint (not QA-tuned)
    backbone_name: str = "huawei-noah/TinyBERT_General_4L_312D"
    max_sequence_length: int = 256
    learning_rate: float = 2e-5
    num_epochs: int = 5
    batch_size: int = 64
    eval_batch_size: int = 2048
    no_answer_threshold: float = 0.0


@dataclass(frozen=True)
class OriginalBertQAConfig(BertQAConfig):
    """
    Config for a BERT-based extractive QA system (original BERT model).
    """

    MODEL_TYPE: ClassVar[str] = "original_bert_qa"
    # General-purpose checkpoint (not QA-tuned)
    backbone_name: str = "bert-base-uncased"
    max_sequence_length: int = 384
    learning_rate: float = 5e-5
    num_epochs: int = 2
    batch_size: int = 48
    eval_batch_size: int = 1024
    no_answer_threshold: float = 0.5
src/etl/__init__.py ADDED
File without changes
src/etl/squad_v2_loader.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Contains core ETL functionality to load train/dev datasets.
3
+ """
4
+
5
+ from typing import Dict
6
+ from pathlib import Path
7
+ import json
8
+ import pandas as pd
9
+ from src.etl.types import QAExample
10
+ from src.utils.constants import Col, RawField
11
+
12
+ DEFAULT_ENCODING = "utf-8"
13
+
14
+
15
def load_squad_v2_df(file_path: Path) -> pd.DataFrame:
    """
    Read a SQuAD v2.0 JSON file into a pandas DataFrame.

    Produced columns:
        - Col.QUESTION_ID.value : str (unique)
        - Col.TITLE.value : str
        - Col.CONTEXT.value : str
        - Col.QUESTION.value : str
        - Col.IS_IMPOSSIBLE.value : bool
        - Col.ANSWER_TEXTS.value : List[str] (all gold answers; [] if impossible)
        - Col.ANSWER_STARTS.value : List[int] (all start offsets; [] if impossible)
        - Col.NUM_ANSWERS.value : int (len(answers))
    """
    assert file_path.exists(), f"File not found: {file_path}"
    with file_path.open("r", encoding=DEFAULT_ENCODING) as handle:
        payload = json.load(handle)

    # Fail fast on anything that is not a v2.0 dump with the expected top keys.
    assert (
        set(payload.keys()) == {RawField.VERSION.value, RawField.DATA.value}
        and payload[RawField.VERSION.value] == "v2.0"
    ), "Unexpected input data formatting."

    records = []
    for article in payload[RawField.DATA.value]:
        article_title = article[Col.TITLE.value]
        for paragraph in article[RawField.PARAGRAPHS.value]:
            paragraph_context = paragraph[Col.CONTEXT.value]
            for qa in paragraph[RawField.QAS.value]:
                # Gold answers may legitimately be empty (unanswerable question).
                raw_answers = qa[RawField.ANSWERS.value]
                assert isinstance(raw_answers, list), "Unexpected raw answers type."
                texts = [ans[RawField.ANSWER_TEXT.value] for ans in raw_answers]
                starts = [ans[RawField.ANSWER_START.value] for ans in raw_answers]

                # Structural check: every gold text needs a matching offset.
                assert len(texts) == len(
                    starts
                ), f"Mismatched gold lengths for {qa[Col.QUESTION_ID.value]}"

                records.append(
                    {
                        Col.QUESTION_ID.value: qa[Col.QUESTION_ID.value],
                        Col.TITLE.value: article_title,
                        Col.CONTEXT.value: paragraph_context,
                        Col.QUESTION.value: qa[Col.QUESTION.value],
                        Col.IS_IMPOSSIBLE.value: bool(qa[Col.IS_IMPOSSIBLE.value]),
                        Col.ANSWER_TEXTS.value: texts,
                        Col.ANSWER_STARTS.value: starts,
                        Col.NUM_ANSWERS.value: len(texts),
                    }
                )

    df = pd.DataFrame(records)
    assert (
        df[Col.QUESTION_ID.value].duplicated().sum() == 0
    ), "Unexpected non-unique question ID."
    return df
74
+
75
+
76
def df_to_examples_map(df: pd.DataFrame) -> Dict[str, QAExample]:
    """
    Convert DF -> Dict[question ID, QAExample].
    The loader has already asserted ID uniqueness and basic row structure.
    """
    expected = {
        Col.QUESTION_ID.value,
        Col.TITLE.value,
        Col.CONTEXT.value,
        Col.QUESTION.value,
        Col.IS_IMPOSSIBLE.value,
        Col.ANSWER_TEXTS.value,
        Col.ANSWER_STARTS.value,
    }
    missing = expected - set(df.columns)
    assert not missing, f"Missing required columns: {sorted(missing)}"

    examples: Dict[str, QAExample] = {}
    for _, record in df.iterrows():
        qid = record[Col.QUESTION_ID.value]
        assert qid not in examples, f"Duplicate id during build: {qid}"
        examples[qid] = QAExample(
            question_id=qid,
            title=record[Col.TITLE.value],
            question=record[Col.QUESTION.value],
            context=record[Col.CONTEXT.value],
            # list(...) defends against accidentally shared list references
            answer_texts=list(record[Col.ANSWER_TEXTS.value] or []),
            answer_starts=list(record[Col.ANSWER_STARTS.value] or []),
            is_impossible=record[Col.IS_IMPOSSIBLE.value],
        )
    return examples
src/etl/types.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Creates frozen dataclass objects per individual ground-truth example and individual prediction.
3
+
4
+ Benefits:
5
+ - Instance immutability: avoids accidental changes to data which would be otherwise unexpected
6
+ - Explicit type annotation across object fields, removes ambiguity
7
+ - Compact implementation: reduces boilerplate code (e.g., __init__() is auto-generated)
8
+ - Post-init preserves consistent validation for each and every object created
9
+ """
10
+
11
+ from __future__ import annotations
12
+ from dataclasses import dataclass
13
+ from typing import List, Dict
14
+
15
+
16
@dataclass(frozen=True)
class QAExample:
    """
    One gold (ground-truth) QA instance taken from SQuAD, kept as a frozen
    dataclass so it cannot be mutated while the code runs.
    All gold answers are stored, matching the official evaluation script.
    When is_impossible is True, answer_texts and answer_starts must be empty;
    __post_init__() enforces that invariant on every construction.
    """

    question_id: str
    title: str
    question: str
    context: str
    answer_texts: List[str]  # [] when is_impossible is True
    answer_starts: List[int]  # [] when is_impossible is True
    is_impossible: bool

    def __post_init__(self):
        # Reject truthy non-bool values (e.g. 1/"yes") for is_impossible.
        if not isinstance(self.is_impossible, bool):
            raise ValueError("is_impossible field needs to be of boolean type.")

        if len(self.answer_texts) != len(self.answer_starts):
            raise ValueError(
                "Incompatible sizes of answer_texts/answer_starts of QAExample."
            )
        # Answer lists and the impossibility flag must agree.
        if self.is_impossible and (self.answer_texts or self.answer_starts):
            raise ValueError(
                "Incompatible configuration between is_impossible (True) Vs answer_texts/answer_starts (non-empty) of QAExample."
            )
        if not self.is_impossible and (
            not self.answer_texts or not self.answer_starts
        ):
            raise ValueError(
                "Incompatible configuration between is_impossible (False) Vs answer_texts/answer_starts (empty) of QAExample."
            )
+
53
+
54
+ @dataclass(frozen=True)
55
+ class Prediction:
56
+ """
57
+ Single model prediction for a question.
58
+ __post_init__() method validates for consistency with expected values.
59
+ """
60
+
61
+ question_id: str
62
+ predicted_answer: str # '' if the model predicts no-answer
63
+ confidence: float # corresponds to the confidence level that the question is answerable via the context
64
+ is_impossible: bool
65
+
66
+ def __post_init__(self):
67
+ if not (0 <= self.confidence <= 1):
68
+ raise ValueError(
69
+ "Confidence of Prediction object should be a probability score [0, 1]."
70
+ )
71
+
72
+ @classmethod
73
+ def null(cls, question_id: str, confidence: float = 0.0) -> Prediction:
74
+ """
75
+ No-answer Prediction constructor to standardize it throughout the code.
76
+ """
77
+ return cls(
78
+ question_id=question_id,
79
+ predicted_answer="",
80
+ confidence=confidence,
81
+ is_impossible=True,
82
+ )
83
+
84
+ @classmethod
85
+ def flatten_predicted_answers(
86
+ cls, predictions: Dict[str, Prediction]
87
+ ) -> Dict[str, str]:
88
+ """
89
+ Convert Dict[qid, Prediction] -> Dict[qid, str] -
90
+ similar to official evaluation script style.
91
+ """
92
+ # TODO - add an extra check that each key of the Dict matches with the
93
+ # question ID stored as part of the Prediction object
94
+ return {qid: p.predicted_answer for qid, p in predictions.items()}
src/evaluation/__init__.py ADDED
File without changes
src/evaluation/evaluator.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Specifies the Evaluator's functionality.
3
+ Leverages metrics as computed in the official SQuAD v2.0 evaluation
4
+ script to ensure reporting consistency.
5
+ """
6
+
7
+ from typing import Dict, List
8
+ from src.evaluation.metrics import Metrics
9
+ from src.etl.types import QAExample, Prediction
10
+ from src.evaluation.squad_v2_official import (
11
+ normalize_answer,
12
+ compute_exact,
13
+ compute_f1,
14
+ )
15
+
16
+
17
+ class Evaluator:
18
+ def evaluate(
19
+ self, predictions: Dict[str, Prediction], examples: Dict[str, QAExample]
20
+ ) -> Metrics:
21
+
22
+ assert len(examples) > 0, "Examples must be non-empty."
23
+ assert isinstance(predictions, dict) and isinstance(
24
+ examples, dict
25
+ ), "Inputs must be dicts."
26
+ extras = set(predictions.keys()).symmetric_difference(set(examples.keys()))
27
+ assert (
28
+ not extras
29
+ ), f"Differences across predictions/examples question ids: {list(sorted(extras))[:3]} ..."
30
+
31
+ golds: Dict[str, List[str]] = {}
32
+ for qid, ex in examples.items():
33
+ if ex.is_impossible:
34
+ golds[qid] = [""]
35
+ else:
36
+ # similar to the official script - filter out golds which normalize to empty
37
+ filtered = [t for t in ex.answer_texts if normalize_answer(str(t))]
38
+ golds[qid] = filtered if filtered else [""]
39
+
40
+ em_sum = 0.0
41
+ f1_sum = 0.0
42
+
43
+ for qid, gold_list in golds.items():
44
+ pred_obj = predictions.get(qid)
45
+ if not pred_obj:
46
+ raise ValueError(
47
+ "Unexpected absence of Prediction object for question ID:%s" % qid
48
+ )
49
+ pred_text = pred_obj.predicted_answer
50
+ assert isinstance(pred_text, str), "Unexpected predicted answer type."
51
+
52
+ best_em = max((compute_exact(g, pred_text) for g in gold_list), default=0)
53
+ best_f1 = max((compute_f1(g, pred_text) for g in gold_list), default=0.0)
54
+
55
+ em_sum += float(best_em)
56
+ f1_sum += float(best_f1)
57
+
58
+ total = len(golds)
59
+ assert total >= 1, "Unexpected empty dict of ground-truth items."
60
+ return Metrics(
61
+ exact_score=100.0 * (em_sum / total),
62
+ f1_score=100.0 * (f1_sum / total),
63
+ total_num_instances=total,
64
+ )
src/evaluation/inspect_scores.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Contains supplementary routines for post-hoc validation/inspection of the results:
3
+ - Additional safeguard that dev set results are reliable (external recomputation of F1/EM metrics).
4
+ - Offers example-level inspection to users.
5
+ """
6
+
7
+ import json
8
+ import pandas as pd
9
+ from pathlib import Path
10
+ from src.utils.constants import Col
11
+ from src.evaluation.squad_v2_official import normalize_answer, compute_exact, compute_f1
12
+
13
+
14
def validate_experiment(exp_dir: Path, df: pd.DataFrame) -> pd.DataFrame:
    """Load predictions, recompute EM/F1, and cross-check against saved metrics."""
    exp_dir = Path(exp_dir)

    # Join saved predictions onto the gold dataframe, keyed by question ID.
    predictions = json.loads((exp_dir / "predictions.json").read_text())
    df_eval = df.set_index(Col.QUESTION_ID.value).join(
        pd.Series(predictions, name="predicted_answer")
    )
    assert df_eval["predicted_answer"].isna().sum() == 0, "Missing predictions"

    df_eval = _compute_scores(df_eval)
    computed_em = 100.0 * df_eval["em_score"].mean()
    computed_f1 = 100.0 * df_eval["f1_score"].mean()

    # Compare against the metrics persisted at training time.
    saved = json.loads((exp_dir / "metrics.json").read_text())
    saved_em, saved_f1 = saved["exact_score"], saved["f1_score"]

    print(f"\n{exp_dir.name}")
    print(f"Computed: EM={computed_em:.2f}%, F1={computed_f1:.2f}%")
    print(f"Saved: EM={saved_em:.2f}%, F1={saved_f1:.2f}%")
    if abs(computed_em - saved_em) < 0.01 and abs(computed_f1 - saved_f1) < 0.01:
        print("MATCH\n")
    else:
        print("MISMATCH - check evaluation\n")
    return df_eval
40
+
41
+
42
def _compute_scores(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of df with per-row em_score and f1_score columns added."""
    em_scores = []
    f1_scores = []
    for _, record in df.iterrows():
        pred = record["predicted_answer"]
        golds = record[Col.ANSWER_TEXTS.value]

        # Mirror the official script: unanswerable -> [""], and drop golds
        # that normalize to the empty string.
        if not golds:
            golds = [""]
        else:
            golds = [g for g in golds if normalize_answer(str(g))] or [""]

        em_scores.append(max((compute_exact(g, pred) for g in golds), default=0))
        f1_scores.append(max((compute_f1(g, pred) for g in golds), default=0.0))

    out = df.copy()
    out["em_score"] = em_scores
    out["f1_score"] = f1_scores
    return out
src/evaluation/metrics.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Lightweight Metrics container.

Benefits:
- Facilitates addition/removal of fields without breaking callers.
- Better isolation of responsibilities around code exporting metrics for experiment tracking.
"""

from dataclasses import dataclass, asdict
from typing import Any, Dict


@dataclass(frozen=True)
class Metrics:
    # Minimal required fields, aligned with the official script's main metrics.
    exact_score: float
    f1_score: float
    total_num_instances: int

    def export_for_exp_tracking(self) -> Dict[str, Any]:
        """
        Export a dict for experiment artifacts, dropping None-valued keys.
        """
        return {
            key: value for key, value in asdict(self).items() if value is not None
        }
src/evaluation/squad_v2_official.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Official evaluation script for SQuAD version 2.0.
2
+
3
+ In addition to basic functionality, we also compute additional statistics and
4
+ plot precision-recall curves if an additional na_prob.json file is provided.
5
+ This file is expected to map question ID's to the model's predicted probability
6
+ that a question is unanswerable.
7
+
8
+ TODO: Preserve only functions used in prod (i.e., metrics).
9
+ The full file is temporaririly maintained to ensure parity between
10
+ the official evaluation script Vs in-house prod metrics.
11
+ """
12
+
13
+ import argparse
14
+ import collections
15
+ import json
16
+ import numpy as np
17
+ import os
18
+ import re
19
+ import string
20
+ import sys
21
+
22
+ OPTS = None
23
+
24
+
25
def parse_args():
    """Build and parse the CLI arguments; print help and exit when none are given."""
    parser = argparse.ArgumentParser(
        "Official evaluation script for SQuAD version 2.0."
    )
    # Required positionals.
    parser.add_argument("data_file", metavar="data.json", help="Input data JSON file.")
    parser.add_argument("pred_file", metavar="pred.json", help="Model predictions.")
    # Optional outputs and no-answer handling.
    parser.add_argument(
        "--out-file",
        "-o",
        metavar="eval.json",
        help="Write accuracy metrics to file (default is stdout).",
    )
    parser.add_argument(
        "--na-prob-file",
        "-n",
        metavar="na_prob.json",
        help="Model estimates of probability of no answer.",
    )
    parser.add_argument(
        "--na-prob-thresh",
        "-t",
        type=float,
        default=1.0,
        help='Predict "" if no-answer probability exceeds this (default = 1.0).',
    )
    parser.add_argument(
        "--out-image-dir",
        "-p",
        metavar="out_images",
        default=None,
        help="Save precision-recall curves to directory.",
    )
    parser.add_argument("--verbose", "-v", action="store_true")
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args()
62
+
63
+
64
def make_qid_to_has_ans(dataset):
    """Map each question id to whether it has at least one gold answer."""
    return {
        qa["id"]: bool(qa["answers"])
        for article in dataset
        for p in article["paragraphs"]
        for qa in p["qas"]
    }
71
+
72
+
73
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    # Same pipeline as the original: lower -> strip punctuation ->
    # drop a/an/the -> collapse whitespace.
    punctuation = set(string.punctuation)
    lowered = s.lower()
    no_punct = "".join(ch for ch in lowered if ch not in punctuation)
    no_articles = re.sub(re.compile(r"\b(a|an|the)\b", re.UNICODE), " ", no_punct)
    return " ".join(no_articles.split())
91
+
92
+
93
def get_tokens(s):
    """Whitespace-tokenize the normalized string; empty/None input -> no tokens."""
    return normalize_answer(s).split() if s else []
97
+
98
+
99
def compute_exact(a_gold, a_pred):
    """1 if the normalized gold and prediction match exactly, else 0."""
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))
101
+
102
+
103
def compute_f1(a_gold, a_pred):
    """Token-level F1 between a gold answer and a predicted answer."""
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    if not gold_toks or not pred_toks:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return int(gold_toks == pred_toks)
    overlap = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(overlap.values())
    if num_same == 0:
        return 0
    precision = num_same / len(pred_toks)
    recall = num_same / len(gold_toks)
    return (2 * precision * recall) / (precision + recall)
117
+
118
+
119
def get_raw_scores(dataset, preds):
    """Per-question exact and F1 scores, maximized over all gold answers."""
    exact_scores = {}
    f1_scores = {}
    for article in dataset:
        for p in article["paragraphs"]:
            for qa in p["qas"]:
                qid = qa["id"]
                if qid not in preds:
                    print("Missing prediction for %s" % qid)
                    continue
                a_pred = preds[qid]
                # Keep only golds that normalize to something non-empty.
                gold_answers = [
                    a["text"] for a in qa["answers"] if normalize_answer(a["text"])
                ]
                if not gold_answers:
                    # For unanswerable questions, only correct answer is empty string
                    gold_answers = [""]
                # Take max over all gold answers
                exact_scores[qid] = max(
                    compute_exact(a, a_pred) for a in gold_answers
                )
                f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
    return exact_scores, f1_scores
140
+
141
+
142
def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
    """Replace scores for questions whose no-answer probability exceeds the threshold.

    Above the threshold the question is treated as predicted-unanswerable:
    credit 1.0 when it truly has no answer, 0.0 otherwise.
    """
    adjusted = {}
    for qid, score in scores.items():
        if na_probs[qid] > na_prob_thresh:
            adjusted[qid] = float(not qid_to_has_ans[qid])
        else:
            adjusted[qid] = score
    return adjusted
151
+
152
+
153
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
    """Aggregate mean EM/F1 (x100) over all qids, or over a provided subset."""
    if qid_list:
        total = len(qid_list)
        exact_sum = sum(exact_scores[k] for k in qid_list)
        f1_sum = sum(f1_scores[k] for k in qid_list)
    else:
        total = len(exact_scores)
        exact_sum = sum(exact_scores.values())
        f1_sum = sum(f1_scores.values())
    return collections.OrderedDict(
        [
            ("exact", 100.0 * exact_sum / total),
            ("f1", 100.0 * f1_sum / total),
            ("total", total),
        ]
    )
172
+
173
+
174
def merge_eval(main_eval, new_eval, prefix):
    """Copy new_eval entries into main_eval (in place) under prefixed keys."""
    for key, value in new_eval.items():
        main_eval["%s_%s" % (prefix, key)] = value
177
+
178
+
179
def plot_pr_curve(precisions, recalls, out_image, title):
    """Render a step-style precision-recall curve and save it to out_image.

    NOTE(review): relies on a module-level `plt` (presumably matplotlib.pyplot)
    that is not imported in the visible portion of this file - confirm it is
    bound before this function is called.
    """
    plt.step(recalls, precisions, color="b", alpha=0.2, where="post")
    plt.fill_between(recalls, precisions, step="post", alpha=0.2, color="b")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.xlim([0.0, 1.05])
    plt.ylim([0.0, 1.05])
    plt.title(title)
    plt.savefig(out_image)
    plt.clf()
189
+
190
+
191
def make_precision_recall_eval(
    scores, na_probs, num_true_pos, qid_to_has_ans, out_image=None, title=None
):
    """Sweep no-answer-probability thresholds and return average precision (x100).

    Questions are visited in increasing na_prob order; precision/recall points
    are emitted only where a threshold could actually be placed (i.e. between
    distinct na_prob values). Optionally plots the curve when out_image is set.
    """
    ordered_qids = sorted(na_probs, key=lambda k: na_probs[k])
    true_pos = 0.0
    cur_p, cur_r = 1.0, 0.0
    precisions, recalls = [1.0], [0.0]
    avg_prec = 0.0
    for i, qid in enumerate(ordered_qids):
        if qid_to_has_ans[qid]:
            true_pos += scores[qid]
        cur_p = true_pos / float(i + 1)
        cur_r = true_pos / float(num_true_pos)
        at_end = i == len(ordered_qids) - 1
        if at_end or na_probs[qid] != na_probs[ordered_qids[i + 1]]:
            # i.e., if we can put a threshold after this point
            avg_prec += cur_p * (cur_r - recalls[-1])
            precisions.append(cur_p)
            recalls.append(cur_r)
    if out_image:
        plot_pr_curve(precisions, recalls, out_image, title)
    return {"ap": 100.0 * avg_prec}
214
+
215
+
216
def run_precision_recall_analysis(
    main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, out_image_dir
):
    """Compute PR curves (exact, F1, oracle) and merge their APs into main_eval."""
    if out_image_dir and not os.path.exists(out_image_dir):
        os.makedirs(out_image_dir)
    num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
    if num_true_pos == 0:
        # No answerable questions: recall is undefined, nothing to analyze.
        return

    def _pr_eval(raw_scores, image_name, curve_title):
        # Shared invocation; only the scores, image path and title vary.
        return make_precision_recall_eval(
            raw_scores,
            na_probs,
            num_true_pos,
            qid_to_has_ans,
            out_image=os.path.join(out_image_dir, image_name),
            title=curve_title,
        )

    pr_exact = _pr_eval(
        exact_raw, "pr_exact.png", "Precision-Recall curve for Exact Match score"
    )
    pr_f1 = _pr_eval(f1_raw, "pr_f1.png", "Precision-Recall curve for F1 score")
    # Oracle: score 1 exactly for the answerable questions.
    oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
    pr_oracle = _pr_eval(
        oracle_scores,
        "pr_oracle.png",
        "Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)",
    )
    merge_eval(main_eval, pr_exact, "pr_exact")
    merge_eval(main_eval, pr_f1, "pr_f1")
    merge_eval(main_eval, pr_oracle, "pr_oracle")
252
+
253
+
254
def histogram_na_prob(na_probs, qid_list, image_dir, name):
    """
    Save a 20-bin histogram of no-answer probabilities for the given qids.

    Args:
        na_probs: qid -> no-answer probability.
        qid_list: subset of question IDs to plot; no-op when empty.
        image_dir: directory receiving the output PNG.
        name: label used in the plot title and output filename.
    """
    if not qid_list:
        return
    probs = [na_probs[qid] for qid in qid_list]
    # Uniform weights normalize the bars to proportions instead of counts.
    bin_weights = np.ones_like(probs) / float(len(probs))
    plt.hist(probs, weights=bin_weights, bins=20, range=(0.0, 1.0))
    plt.xlabel("Model probability of no-answer")
    plt.ylabel("Proportion of dataset")
    plt.title("Histogram of no-answer probability: %s" % name)
    plt.savefig(os.path.join(image_dir, "na_prob_hist_%s.png" % name))
    plt.clf()
265
+
266
+
267
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
    """
    Sweep the no-answer probability threshold and return the best achievable
    aggregate score (as a percentage of len(scores)) and the threshold that
    attains it.

    Args:
        preds: qid -> predicted answer string ("" means no-answer).
        scores: qid -> per-question score when the question is answered.
        na_probs: qid -> model probability of no-answer.
        qid_to_has_ans: qid -> whether the question is answerable.
    """
    # Baseline: answering nothing gets every unanswerable question right.
    num_no_ans = sum(1 for has_ans in qid_to_has_ans.values() if not has_ans)
    running_score = num_no_ans
    best_score = running_score
    best_thresh = 0.0
    # Walk questions from most to least confident (lowest na_prob first);
    # each step simulates lowering the threshold to include one more answer.
    for qid in sorted(na_probs, key=na_probs.get):
        if qid not in scores:
            continue
        if qid_to_has_ans[qid]:
            delta = scores[qid]
        elif preds[qid]:
            delta = -1  # answered an unanswerable question: lose its point
        else:
            delta = 0  # predicted no-answer anyway: nothing changes
        running_score += delta
        if running_score > best_score:
            best_score = running_score
            best_thresh = na_probs[qid]
    return 100.0 * best_score / len(scores), best_thresh
288
+
289
+
290
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
    """
    Tune the no-answer threshold separately for EM and F1 and record the
    best scores and thresholds in main_eval (keys: best_exact,
    best_exact_thresh, best_f1, best_f1_thresh).
    """
    for metric_name, raw_scores in (("exact", exact_raw), ("f1", f1_raw)):
        best_score, best_thresh = find_best_thresh(
            preds, raw_scores, na_probs, qid_to_has_ans
        )
        main_eval["best_%s" % metric_name] = best_score
        main_eval["best_%s_thresh" % metric_name] = best_thresh
299
+
300
+
301
def main():
    """
    Entry point: load dataset/predictions (and optional no-answer probs),
    score them, optionally tune thresholds and emit PR/ histogram plots,
    then write or print the evaluation dictionary.
    """
    with open(OPTS.data_file) as f:
        dataset = json.load(f)["data"]
    with open(OPTS.pred_file) as f:
        preds = json.load(f)
    if OPTS.na_prob_file:
        with open(OPTS.na_prob_file) as f:
            na_probs = json.load(f)
    else:
        # Without model-provided probabilities, treat every prediction as
        # fully confident that the question is answerable.
        na_probs = {qid: 0.0 for qid in preds}
    qid_to_has_ans = make_qid_to_has_ans(dataset)  # maps qid to True/False
    has_ans_qids = [qid for qid, has_ans in qid_to_has_ans.items() if has_ans]
    no_ans_qids = [qid for qid, has_ans in qid_to_has_ans.items() if not has_ans]
    exact_raw, f1_raw = get_raw_scores(dataset, preds)
    exact_thresh = apply_no_ans_threshold(
        exact_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh
    )
    f1_thresh = apply_no_ans_threshold(
        f1_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh
    )
    out_eval = make_eval_dict(exact_thresh, f1_thresh)
    # Break out per-subset metrics for answerable / unanswerable questions.
    for subset_name, subset_qids in (("HasAns", has_ans_qids), ("NoAns", no_ans_qids)):
        if subset_qids:
            subset_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=subset_qids)
            merge_eval(out_eval, subset_eval, subset_name)
    if OPTS.na_prob_file:
        find_all_best_thresh(
            out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans
        )
        if OPTS.out_image_dir:
            run_precision_recall_analysis(
                out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, OPTS.out_image_dir
            )
            histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, "hasAns")
            histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, "noAns")
    if OPTS.out_file:
        with open(OPTS.out_file, "w") as f:
            json.dump(out_eval, f)
    else:
        print(json.dumps(out_eval, indent=2))
344
+
345
+
346
if __name__ == "__main__":
    OPTS = parse_args()
    if OPTS.out_image_dir:
        # Lazy import: matplotlib is only needed when plots are requested.
        # The Agg backend must be selected BEFORE importing pyplot so figures
        # can be written to disk on headless machines (no display required).
        import matplotlib

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    main()
src/models/__init__.py ADDED
File without changes
src/models/always_no_answer_model.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Always-no-answer baseline: returns a standardized null Prediction for every question.
3
+ """
4
+
5
+ from typing import Dict, Optional
6
+ from src.models.base_qa_model import QAModel
7
+ from src.etl.types import QAExample, Prediction
8
+ from src.config.model_configs import AlwaysNoAnswerModelConfig
9
+
10
+
11
class AlwaysNoAnswerQAModel(QAModel):
    """
    Trivial baseline: every question receives the standardized null
    Prediction, i.e. it is always predicted as unanswerable ("").
    """

    def __init__(self, config: AlwaysNoAnswerModelConfig) -> None:
        super().__init__()
        assert isinstance(
            config, AlwaysNoAnswerModelConfig
        ), "Incompatible configuration object."
        self.config = config

    def train(
        self,
        train_examples: Optional[Dict[str, QAExample]] = None,
        val_examples: Optional[Dict[str, QAExample]] = None,
    ) -> None:
        """No-op: this baseline has no parameters. Kept for QAModel API parity."""
        return

    def predict(self, examples: Dict[str, QAExample]) -> Dict[str, Prediction]:
        """Map every question ID to a null (no-answer) Prediction."""
        assert isinstance(examples, dict), "Incompatible input examples type."
        predictions: Dict[str, Prediction] = {}
        for qid in examples:
            predictions[qid] = Prediction.null(question_id=qid)
        return predictions
src/models/base_qa_model.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Optional
3
+ from src.etl.types import QAExample, Prediction
4
+
5
+
6
class QAModel(ABC):
    """Basic contract dictating specific QA model implementation requirements."""

    @abstractmethod
    def train(
        self,
        train_examples: Dict[str, QAExample],
        val_examples: Optional[Dict[str, QAExample]] = None,
    ) -> None:
        """
        Trains the model; assumes uniqueness of keys of train_examples (unique question IDs).

        Args:
            train_examples: mapping of question ID -> QAExample used for fitting.
            val_examples: optional held-out examples a subclass may use for
                per-epoch evaluation (e.g., early stopping); may be ignored.
        """
        raise NotImplementedError

    @abstractmethod
    def predict(self, examples: Dict[str, QAExample]) -> Dict[str, Prediction]:
        """
        Produces one Prediction per question ID.

        Args:
            examples: mapping of question ID -> QAExample to answer.

        Returns:
            Mapping of the same question IDs to their Prediction objects.
        """
        raise NotImplementedError
src/models/bert_based_model.py ADDED
@@ -0,0 +1,639 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Contains functionality adapting a general-purpose BERT-type model
3
+ for the QA task. The BertBasedQAModel fully aligns with the structure of
4
+ other models (i.e., sub-classing QAModel for consistency); and stores a custom
5
+ QAModule which specifies the wiring of the general-purpose model's representations
6
+ with the linear NN layer needed for the QA task.
7
+
8
+ Benefits:
9
+ - Facilitates a **plug-and-play** selection of the underlying encoder model.
10
+ - Follows a clean, composition pattern, avoiding double inheritance of both
11
+ QAModel and torch.nn.Module which may introduce unnecessary complexity
12
+ (e.g., which __init__() is called, which train() is called, etc.)
13
+ """
14
+
15
+ import torch
16
+ import random
17
+ import json
18
+ import numpy as np
19
+ from dataclasses import asdict
20
+ from pathlib import Path
21
+ from typing import Dict, Optional, List, Tuple
22
+ from transformers import AutoTokenizer, AutoModel
23
+ from transformers.tokenization_utils_base import BatchEncoding
24
+ from torch.utils.data import Dataset, DataLoader
25
+
26
+ from src.models.base_qa_model import QAModel
27
+ from src.config.model_configs import BertQAConfig
28
+ from src.etl.types import QAExample, Prediction
29
+ from src.evaluation.evaluator import Evaluator, Metrics
30
+ from src.utils.constants import DEBUG_SEED
31
+
32
+
33
def set_seed(seed: int = DEBUG_SEED) -> None:
    """
    Seed Python's `random`, NumPy, and PyTorch (CPU, CUDA, and MPS) RNGs
    for reproducible runs.

    NOTE - this is mainly to facilitate experimentation progress; options such
    as torch.backends.cudnn.benchmark = False may hurt performance and thus running
    this function may need to be skipped in production.

    Relevant resources:
    - https://stackoverflow.com/questions/67581281/does-torch-manual-seed-include-the-operation-of-torch-cuda-manual-seed-all
    - https://docs.pytorch.org/docs/stable/notes/randomness.html

    # TODO - move to utilities file
    """
    # CPU-side RNGs shared by all backends
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA (NVIDIA GPUs): seed every visible device and force deterministic
    # cuDNN kernel selection.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # MPS (Apple Silicon) keeps a separate RNG state.
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
59
+
60
+
61
class QADataset(Dataset):
    """
    Thin adapter exposing a Dict[str, QAExample] as a torch Dataset so that
    DataLoader can batch and shuffle examples without any manual index
    bookkeeping on our side.

    # TODO - move to utilities file
    """

    def __init__(self, examples_dict: Dict[str, QAExample]):
        """Materialize the dict values once; DataLoader indexes into this list."""
        self.examples = list(examples_dict.values())

    def __len__(self) -> int:
        """Number of stored examples; DataLoader uses this to plan batches."""
        return len(self.examples)

    def __getitem__(self, idx: int) -> QAExample:
        """Fetch the example stored at position idx."""
        return self.examples[idx]
81
+
82
+
83
class BertBasedQAModel(QAModel):
    """
    QAModel implementation backed by a HuggingFace BERT-family encoder.

    Owns the tokenizer and a QAModule (encoder + linear span head) and
    implements: training with a span-extraction objective, chunked batched
    prediction with a no-answer threshold, and loading a trained model from
    an experiment directory.
    """

    def __init__(self, config: BertQAConfig) -> None:
        super().__init__()
        # Reproducible weight initialization
        set_seed()
        assert isinstance(config, BertQAConfig), "Incompatible configuration object."
        self.config = config

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.backbone_name, use_fast=True
        )
        self.qa_module = QAModule(config=self.config)

        # Sanity check to ensure that [CLS] token is always at position 0;
        # This assumption is used in the code for predicting non-answerable questions
        test_encoding = self.tokenizer("testQ", "testC", return_tensors="pt")
        assert (
            # [0, 0] --> [first (and only) example of batch, first sequence token for example]
            test_encoding["input_ids"][0, 0].item()
            == self.tokenizer.cls_token_id
        ), "Model doesn't follow BERT's [CLS]-at-position-0 convention."

    @classmethod
    def load_from_experiment(
        cls, experiment_dir: Path, config_class, device: str = "mps"
    ):
        """
        Loads model from the experiment tracking directory.

        experiment_dir: Path to the experiment (e.g., 'experiments/<date_time>_bert-base_ALL_articles')
        device: by default we load into Apple MPS for local experimentation with predictions (e.g., threshold tuning)
        """
        experiment_dir = Path(experiment_dir)
        model_dir = experiment_dir / "model"
        if not model_dir.exists():
            raise FileNotFoundError(f"Model directory not found: {model_dir}")

        print(f"\nLoading model from experiment: {experiment_dir.name}")
        with open(experiment_dir / "config.json", "r") as f:
            config_dict = json.load(f)

        # Override device
        config_dict["device"] = device
        config = config_class(**config_dict)

        model = cls(config)

        tokenizer_path = model_dir / "tokenizer"
        if not tokenizer_path.exists():
            raise FileNotFoundError(f"Tokenizer not found: {tokenizer_path}")
        model.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)

        weights_path = model_dir / "pytorch_model.bin"
        if not weights_path.exists():
            raise FileNotFoundError(f"Model weights not found: {weights_path}")
        state_dict = torch.load(weights_path, map_location=device)
        model.qa_module.load_state_dict(state_dict)

        # Loaded models default to inference mode; train() flips this back.
        model.qa_module.eval()
        print("Model loaded succesfully and set to eval mode.")
        return model

    def train(
        self,
        train_examples: Optional[Dict[str, QAExample]] = None,
        val_examples: Optional[Dict[str, QAExample]] = None,
    ) -> None:
        """
        Trains the QA model on provided training examples.

        train_examples: qid -> QAExample; must be non-empty.
        val_examples: optional qid -> QAExample used only for per-epoch evaluation.
        """
        # Reproducible training loop
        set_seed()

        # Ensuring dropout is properly configured if it is applied
        self.qa_module.train()

        assert train_examples is not None, "Training examples cannot be None."
        assert len(train_examples) > 0, "Training examples cannot be empty."

        self._print_training_setup(train_examples, val_examples, self.config)

        # Adam is standard for BERT-type models; AdamW handles weight decay better
        optimizer = torch.optim.AdamW(
            self.qa_module.parameters(),  # Trains both encoder and linear head
            lr=self.config.learning_rate,
        )
        # ignore_index=-1: Skip examples where answer wasn't found in tokenization;
        # see _extract_gold_positions() for details
        loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)
        dataset = QADataset(train_examples)
        # Should shuffle to avoid bias towards certain combination of examples within a batch
        dataloader = DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            collate_fn=lambda batch: batch,  # Return list as-is, don't collate
        )
        print(f"Total batches per epoch: {len(dataloader)}")
        print(f"{'='*70}\n")

        for epoch in range(self.config.num_epochs):
            print(f"{'='*70}")
            print(f"EPOCH {epoch + 1}/{self.config.num_epochs}")
            print(f"{'='*70}")
            total_loss = 0.0

            # Logging/debugging: accumulate examples ignored in the loss due to answer truncation
            set_truncated_examples = set()
            for batch_idx, batch_examples in enumerate(dataloader):
                # convert to the format expected by the _prepare_batch() function
                batch_dict = {ex.question_id: ex for ex in batch_examples}
                qids, _, _, encoded = self._prepare_batch(batch_dict)
                assert (
                    len(qids) == encoded["input_ids"].shape[0] == len(batch_examples)
                ), "Training shape mismatch after batch prepare."

                gold_starts, gold_ends = self._extract_gold_positions(
                    batch_examples, encoded, set_truncated_examples
                )

                # Labels must live on the same device as the module's parameters
                device = next(self.qa_module.parameters()).device
                gold_starts = gold_starts.to(device)
                gold_ends = gold_ends.to(device)

                start_logits, end_logits = self.qa_module(
                    input_ids=encoded["input_ids"],
                    attention_mask=encoded.get("attention_mask"),
                    token_type_ids=encoded.get("token_type_ids"),
                )
                # Shape should match (batch_size, sequence_length)
                expected_shape = (len(batch_examples), encoded["input_ids"].shape[1])
                assert (
                    start_logits.shape == expected_shape
                ), f"start_logits shape {start_logits.shape} != expected {expected_shape}"
                assert (
                    end_logits.shape == expected_shape
                ), f"end_logits shape {end_logits.shape} != expected {expected_shape}"

                start_loss = loss_fn(start_logits, gold_starts)
                end_loss = loss_fn(end_logits, gold_ends)

                # Similar to how the original BERT paper defines the objective for SQuAD (Section 4.2)
                loss = (start_loss + end_loss) / 2.0
                assert loss.dim() == 0, f"Loss should be scalar, got shape {loss.shape}"

                # --- Standard backprop flow ---
                # Zero out/initialize gradients from previous batch
                optimizer.zero_grad()
                # Backpropagate gradients
                loss.backward()
                # Update model parameters using computed grads
                optimizer.step()
                total_loss += loss.item()

                if (batch_idx + 1) % 100 == 0 or (batch_idx + 1) == len(dataloader):
                    avg_loss = total_loss / (batch_idx + 1)
                    print(
                        f" Batch {batch_idx + 1}/{len(dataloader)} | Avg Loss: {avg_loss:.4f}"
                    )

            avg_epoch_loss = total_loss / len(dataloader)
            # Currently ignored returned metrics; TODO - use them later for early stopping
            _, _ = self._print_epoch_summary(
                epoch=epoch + 1,
                total_epochs=self.config.num_epochs,
                avg_loss=avg_epoch_loss,
                num_truncated=len(set_truncated_examples),
                train_examples=train_examples,
                val_examples=val_examples,
            )

        print("Training Completed.")
        self.qa_module.eval()

    def _print_epoch_summary(
        self,
        epoch: int,
        total_epochs: int,
        avg_loss: float,
        num_truncated: int,
        train_examples: Dict[str, QAExample],
        val_examples: Optional[Dict[str, QAExample]] = None,
    ) -> Tuple[Metrics, Optional[Metrics]]:
        """
        Logs end-of-epoch statistics and evaluates on the train set (and on
        validation, when provided). Returns the (train, val) metrics.
        """
        if num_truncated > 0:
            print(
                f"{num_truncated} examples truncated throughout the epoch."
                f" Start & end answer tokens could not be identified."
            )
        print(f"\nEpoch {epoch}/{total_epochs} Complete | Average Loss: {avg_loss:.4f}")
        train_metrics = self._evaluate_and_print(train_examples, "Training")
        val_metrics = None
        if val_examples is not None:
            val_metrics = self._evaluate_and_print(val_examples, "Validation")

        # Always resume training mode after evaluation
        self.qa_module.train()
        print(f"{'='*70}\n")
        return train_metrics, val_metrics

    def _evaluate_and_print(
        self, examples: Dict[str, QAExample], split_name: str
    ) -> Metrics:
        """Runs prediction + evaluation on one split and prints its EM/F1."""
        print(f"Evaluating on {split_name} set...")
        predictions = self.predict(examples)
        metrics = Evaluator().evaluate(predictions, examples)
        print(
            f"{split_name} | EM: {metrics.exact_score:.2f}%, F1: {metrics.f1_score:.2f}%"
        )
        return metrics

    def _extract_gold_positions(
        self,
        examples: List[QAExample],
        encoded: BatchEncoding,
        set_truncated_examples: set[str],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Maps character-level answer positions to token-level positions.
        In particular, for each example, the function computes (all offsets are start-inclusive, end-exclusive):
        - the answer offset within the context: [char_start, char_end)
        - each individual token's offset within the context: [token_char_start, token_char_end)

        For two ranges [A, B) and [C, D) to overlap:
        1. The first range should start before the second ends (A < D)
        2. The second range should start before the first ends (C < B)
        These are the conditions the function utilizes to determine an answer's overlap with a specific token.

        Finally, the function picks the FIRST and LAST tokens overlapping with the answer:
        those tokens can fully determine the answer and align with the QA training objective.

        Returns:
        - gold_starts: Tensor (size: batch size) with token index for answer start
        - gold_ends: Tensor (size: batch size) with token index for answer end
        """
        offsets = encoded["offset_mapping"].tolist()
        batch_size = len(examples)
        assert (
            len(offsets) == batch_size
        ), f"Offset mapping size {len(offsets)} != batch size {batch_size}"

        # Accumulate gold positions for each example in the batch
        gold_starts = []
        gold_ends = []
        for i, example in enumerate(examples):

            # Following BERT paper (Section 4.3) - point to [CLS] token (0, 0) for unanswerables
            if example.is_impossible:
                gold_starts.append(0)
                gold_ends.append(0)
                continue
            assert (
                len(example.answer_starts) > 0
            ), f"Answerable question {example.question_id} without valid answers."

            # Simply pick the first available answer (even if multiple are provided)
            answer_text = example.answer_texts[0]
            char_start = example.answer_starts[0]
            char_end = char_start + len(answer_text)

            token_start = None  # Will store first token overlapping with the answer
            token_end = None  # Will store last token overlapping with the answer
            for token_idx, (token_char_start, token_char_end) in enumerate(offsets[i]):

                # skip special tokens ([CLS], [SEP], ...)
                if token_char_start == 0 and token_char_end == 0:
                    continue

                # Need first overlapping token -> check if None
                if token_start is None and token_char_end > char_start:
                    token_start = token_idx

                # Need last overlapping token -> check exhaustively
                if token_char_start < char_end:
                    token_end = token_idx

            if token_start is None or token_end is None:
                # print(
                #     f"Warning! Answer truncated for {example.question_id}, skipping in loss"
                # )
                set_truncated_examples.add(example.question_id)
                # Answer was truncated -> use -1 such that it is ignored for loss computation
                gold_starts.append(-1)
                gold_ends.append(-1)
                continue
            assert (
                token_start <= token_end
            ), f"Invalid token span: start {token_start} > end {token_end}"

            gold_starts.append(token_start)
            gold_ends.append(token_end)

        gold_starts_tensor = torch.tensor(gold_starts, dtype=torch.long)
        gold_ends_tensor = torch.tensor(gold_ends, dtype=torch.long)
        assert (
            len(examples) == len(gold_starts_tensor) == len(gold_ends_tensor)
        ), "Ground-truth token shape mismatch."
        return gold_starts_tensor, gold_ends_tensor

    def predict(
        self, examples: Dict[str, QAExample], threshold_override: Optional[float] = None
    ) -> Dict[str, Prediction]:
        """
        Wrapper that automatically chunks large prediction requests to avoid OOM.

        threshold_override: when given, replaces config.no_answer_threshold
        for this call only (useful for threshold tuning).
        """
        self.qa_module.eval()
        assert isinstance(examples, dict), "Incompatible input examples type."
        assert len(examples) > 0, "No examples to run prediction on."

        eval_batch_size = self.config.eval_batch_size
        if len(examples) <= eval_batch_size:
            return self._predict_batch(examples, threshold_override)

        all_qids = list(examples.keys())
        all_predictions = {}
        # Chunking larger batches to avoid OOM errors
        for i in range(0, len(all_qids), eval_batch_size):
            batch_qids = all_qids[i : i + eval_batch_size]
            batch_examples = {qid: examples[qid] for qid in batch_qids}
            all_predictions.update(
                self._predict_batch(batch_examples, threshold_override)
            )

        return all_predictions

    def _predict_batch(
        self, examples: Dict[str, QAExample], threshold_override: Optional[float] = None
    ) -> Dict[str, Prediction]:
        """
        Processes a single batch of examples:
        encapsulates the forward pass + logic to determine the final model's response
        based on the predicted logits for each token being the start/end of the true answer.
        """
        # Offers overriding the default threshold if this is provided
        threshold = (
            threshold_override
            if threshold_override is not None
            else self.config.no_answer_threshold
        )

        # 1) Batch tokenization
        qids, _, contexts, encoded = self._prepare_batch(examples)

        # 2) Forward pass
        # Inference mode - no gradient calculation
        with torch.no_grad():
            start_logits, end_logits = self.qa_module(
                input_ids=encoded["input_ids"],
                attention_mask=encoded.get("attention_mask"),
                token_type_ids=encoded.get("token_type_ids"),
            )

        # 3) Create context mask: (batch_size, max_sequence_length) boolean tensor;
        # Valid positions: context tokens + [CLS] (for unanswerables);
        # Masked: question tokens, [SEP], padding
        if encoded.get("token_type_ids") is not None:
            # token_type_ids == 1 means context segment (Vs question segment); filter out padding tokens
            context_mask = (encoded["token_type_ids"] == 1) & (
                encoded["attention_mask"] == 1
            )
        else:
            # Fallback for models without token_type_ids (shouldn't happen with BERT)
            context_mask = encoded["attention_mask"] == 1
        # Explicitly allow [CLS] token at position 0 -> predicted token for unanswerables
        context_mask[:, 0] = True
        context_mask = context_mask.to(self.config.device)

        # Apply an extreme negative value to the position associated with filtered-out tokens;
        # avoid neg-inf -> pathological cases where softmax over all neg-inf logits would result in all nans
        MIN_NUMBER = torch.finfo(start_logits.dtype).min
        start_logits = start_logits.masked_fill(~context_mask, MIN_NUMBER)
        end_logits = end_logits.masked_fill(~context_mask, MIN_NUMBER)

        # 4) Simplistic/greedy selection of tokens for start/end of the predicted response;
        # Note that [CLS] is also available to be picked as the most probable token
        best_start_indices = start_logits.argmax(dim=1)
        best_end_indices = end_logits.argmax(dim=1)

        # 5) Extract predictions from token positions
        # offsets reveals where each token maps in the original text;
        # example: token "apple" at token position 3 may map to text[10:15]
        offsets = encoded["offset_mapping"].tolist()
        predictions = {}
        for i, qid in enumerate(qids):
            # edge case - no valid context tokens --> return unanswerable (excluding [CLS] at position 0)
            if not context_mask[i, 1:].any():
                predictions[qid] = Prediction.null(question_id=qid)
                continue

            start_idx = best_start_indices[i].item()
            end_idx = best_end_indices[i].item()

            # Compute null score vs best span score (as per the BERT paper, Section 4.3)
            null_score = start_logits[i, 0].item() + end_logits[i, 0].item()
            best_span_score = (
                start_logits[i, start_idx].item() + end_logits[i, end_idx].item()
            )
            # Predict no-answer if null score exceeds best span by threshold
            if best_span_score <= null_score + threshold:
                predictions[qid] = Prediction.null(question_id=qid)
                continue

            # NOTE: When end_idx < start_idx, the BERT paper specifies searching
            # all valid spans to find the maximum scoring one. For efficiency and simplicity
            # of an initial implementation, we return null. When end_idx >= start_idx, no
            # exhaustive search is necessary (simply picking the best start/end index suffices).
            if end_idx < start_idx:
                predictions[qid] = Prediction.null(question_id=qid)
                continue

            # Map token positions -> character positions in the original text
            start_char, _ = offsets[i][start_idx]  # Character start of first token
            _, end_char = offsets[i][end_idx]  # Character end of last token

            # Special tokens (such as [CLS], [SEP]) have offset [0, 0];
            # mark as unanswerable if we selected a special token
            if start_char == 0 and end_char == 0:
                predictions[qid] = Prediction.null(question_id=qid)
                continue

            assert end_char >= start_char, (
                f"BUG: Invalid character span [{start_char}, {end_char}] "
                f"for valid token span [{start_idx}, {end_idx}] in question {qid}. "
                f"This indicates a problem with offset mapping or token masking."
            )

            # Extract answer text from original context
            answer_text = contexts[i][start_char:end_char].strip()
            # reject whitespace-only responses
            if not answer_text:
                predictions[qid] = Prediction.null(question_id=qid)
                continue

            # Create final prediction
            predictions[qid] = Prediction(
                question_id=qid,
                predicted_answer=answer_text,
                confidence=1.0,  # TODO - use a better way to estimate uncertainty
                is_impossible=False,
            )
        return predictions

    def _prepare_batch(
        self, examples: Dict[str, QAExample]
    ) -> Tuple[List[str], List[str], List[str], BatchEncoding]:
        """
        Extracts questions and contexts in consistent order, then tokenizes them.

        Returns (qids, questions, contexts, encoded) where all lists share the
        same ordering and encoded is the batched tokenizer output.
        """
        qids = list(examples.keys())
        questions = [examples[qid].question for qid in qids]
        contexts = [examples[qid].context for qid in qids]
        encoded = self._encode_pairs(questions, contexts)
        return qids, questions, contexts, encoded

    def _encode_pairs(self, questions: list[str], contexts: list[str]) -> BatchEncoding:
        """
        Standardizes tokenization across all stages (train/inference).
        For more information, refer to the HF documentation, for example see:
        https://huggingface.co/docs/transformers/pad_truncation regarding sequence padding/trunctation.
        """
        assert len(questions) == len(
            contexts
        ), "Question and context lists are incompatible."
        return self.tokenizer(
            text=questions,
            text_pair=contexts,
            truncation="only_second",  # prioritizing truncating context Vs question
            max_length=self.config.max_sequence_length,
            padding="max_length",  # pads to uniform length for conversion to fixed-size tensors
            return_offsets_mapping=True,  # returns (char_start, char_end) for each token
            return_tensors="pt",
        )

    @staticmethod
    def _print_training_setup(
        train_examples: Dict[str, QAExample],
        val_examples: Optional[Dict[str, QAExample]],
        config: BertQAConfig,
    ) -> None:
        """Print training setup information including data splits and configuration."""
        answerable_count = sum(
            1 for ex in train_examples.values() if not ex.is_impossible
        )
        unanswerable_count = len(train_examples) - answerable_count

        print(f"\n{'='*70}")
        print(f"TRAINING SETUP")
        print(f"{'='*70}")
        print(f"Total examples: {len(train_examples)}")
        print(f" Answerable: {answerable_count}")
        print(f" Unanswerable: {unanswerable_count}")
        assert len(train_examples) > 0, "No training examples!"

        if val_examples is not None:
            val_answerable = sum(
                1 for ex in val_examples.values() if not ex.is_impossible
            )
            val_unanswerable = len(val_examples) - val_answerable
            print(
                f"Validation: {len(val_examples)} total ({val_answerable} answerable, {val_unanswerable} unanswerable)"
            )

        print(f"\nConfiguration:")
        print(json.dumps(asdict(config), indent=2))
        print(f"{'='*70}\n")
588
+
589
+
590
class QAModule(torch.nn.Module):
    """
    Torch module pairing a pretrained, general-purpose encoder with a
    2-output linear head that scores every token as a potential answer
    start / answer end position.
    """

    def __init__(self, config: BertQAConfig) -> None:
        super().__init__()
        assert isinstance(config, BertQAConfig), "Incompatible configuration object."
        self.encoder = AutoModel.from_pretrained(config.backbone_name)
        # Read hidden_size off the loaded encoder so any compatible backbone
        # (e.g., DistilBERT, BERT) can be swapped in without code changes.
        hidden_size = self.encoder.config.hidden_size
        self.linear_head = torch.nn.Linear(in_features=hidden_size, out_features=2)

        # Move encoder and head to the configured device in one call
        self.to(config.device)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        input_ids: tokenized integer IDs from the vocabulary
        attention_mask: binary mask reflecting actual token Vs padding token
        token_type_ids: binary mask reflecting the segment: sentence A Vs sentence B

        Returns (start_logits, end_logits), each of shape (B, L).
        """
        # All inputs must live on the same device as the module's parameters
        target_device = next(self.parameters()).device
        input_ids = input_ids.to(target_device)
        attention_mask = (
            attention_mask.to(target_device) if attention_mask is not None else None
        )
        token_type_ids = (
            token_type_ids.to(target_device) if token_type_ids is not None else None
        )

        encoder_out = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Last-layer token representations: (B, L, H) -> linear head -> (B, L, 2)
        logits = self.linear_head(encoder_out.last_hidden_state)
        # Split the final dimension into per-token start and end scores
        return logits[..., 0], logits[..., 1]
src/models/sentence_embedding_model.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Contains a simple baseline for the QA system.
3
+ """
4
+
5
+ import spacy
6
+ from typing import Dict, Optional, List
7
+ from src.config.model_configs import SentenceEmbeddingModelConfig
8
+ from src.models.base_qa_model import QAModel
9
+ from sentence_transformers import SentenceTransformer, util
10
+ from src.etl.types import Prediction, QAExample
11
+
12
+
13
class SentenceEmbeddingQAModel(QAModel):
    """
    Minimal retrieval-style baseline: answers with the single context sentence
    whose sentence-transformers embedding (https://sbert.net/) has the highest
    cosine similarity to the question embedding. Produces a null prediction
    when the context yields no sentences or the best similarity falls below
    the configured no-answer threshold.
    """

    def __init__(self, config: SentenceEmbeddingModelConfig) -> None:
        super().__init__()
        assert isinstance(
            config, SentenceEmbeddingModelConfig
        ), "Incompatible configuration object."
        self.config = config
        # Embedding backbone; placed on whichever device the config selects.
        self._st_model = SentenceTransformer(
            model_name_or_path=config.sentence_model_name,
            device=config.device,
        )
        # Lightweight English pipeline used purely for sentence segmentation.
        self._nlp = spacy.load("en_core_web_sm")

    def train(
        self,
        train_examples: Optional[Dict[str, QAExample]] = None,
        val_examples: Optional[Dict[str, QAExample]] = None,
    ) -> None:
        """No-op: this baseline has no trainable parameters. Kept for API parity with the super-class."""
        return

    def predict(self, examples: Dict[str, QAExample]) -> Dict[str, Prediction]:
        assert isinstance(examples, dict), "Incompatible input examples type."

        results: Dict[str, Prediction] = {}
        for question_id, ex in examples.items():
            context_sentences = self._split_sentences(ex.context)
            if not context_sentences:
                results[question_id] = Prediction.null(question_id=question_id)
                continue

            question_vec = self._st_model.encode(
                ex.question, convert_to_tensor=True, normalize_embeddings=True
            )
            sentence_vecs = self._st_model.encode(
                context_sentences, convert_to_tensor=True, normalize_embeddings=True
            )
            # (1, num_sentences) similarity row -> flat vector of scores
            similarities = util.cos_sim(question_vec, sentence_vecs).squeeze(0)
            winner = int(similarities.argmax().item())
            winner_score = float(similarities[winner])

            if winner_score < self.config.no_answer_threshold:
                # Best match is too weak -> treat the question as unanswerable.
                results[question_id] = Prediction.null(question_id=question_id)
            else:
                results[question_id] = Prediction(
                    question_id=question_id,
                    predicted_answer=context_sentences[winner],
                    confidence=winner_score,
                    is_impossible=False,
                )
        return results

    def _split_sentences(self, text: str) -> List[str]:
        """Segment text into non-empty, stripped sentences via spaCy."""
        cleaned = (text or "").strip()
        if not cleaned:
            return []
        doc = self._nlp(cleaned)
        return [sent.text.strip() for sent in doc.sents if sent.text.strip()]
src/pipeline/__init__.py ADDED
File without changes
src/pipeline/qa_runner.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Contains a simple experimentation pipeline for the QA system.
3
+
4
+ Benefits:
5
+ - Completely **plug-and-play**: the users can easily replace models and configs without
6
+ needing to change pipeline or other code.
7
+ - Automates experiment tracking/versioning: facilitates experimental iteration.
8
+ - Offers data splitting routines promoting model generalization & objective perf measuring:
9
+ a) initial training set gets split into: 'train' Vs 'val' subsets which DO NOT share
10
+ common articles, such that the 'val' set simulates actual held-out article performance;
11
+ b) initial 'dev' set can remain untouched until the very end for objective perf measuring
12
+ """
13
+
14
+ import pandas as pd
15
+ from typing import Tuple, Dict, Optional
16
+ from pathlib import Path
17
+ import sys
18
+ from io import StringIO
19
+ from src.utils.constants import (
20
+ EXPERIMENTS_DIR,
21
+ DEV_DATA_PATH,
22
+ TRAIN_DATA_PATH,
23
+ Col,
24
+ DEBUG_SEED,
25
+ )
26
+ from src.etl.squad_v2_loader import load_squad_v2_df, df_to_examples_map
27
+ from src.models.base_qa_model import QAModel
28
+ from src.etl.types import QAExample
29
+ from src.config.model_configs import BaseModelConfig, BertQAConfig
30
+ from src.evaluation.evaluator import Evaluator
31
+ from src.utils.experiment_snapshot import ExperimentSnapshot
32
+
33
# Fraction of training ARTICLES held out as the validation split (see split_by_title).
DEFAULT_VAL_SET_FRACTION = 0.1
34
+
35
+
36
class Tee:
    """
    Fan-out stream wrapper (cf. the Unix `tee` command): every write is mirrored
    to each wrapped file-like object, so console output can simultaneously be
    captured for experiment tracking.
    Based on: https://stackoverflow.com/questions/616645/how-to-duplicate-sys-stdout-to-a-log-file
    """

    def __init__(self, *files):
        self.files = files

    def write(self, obj):
        # Mirror the payload to every destination, flushing eagerly so the
        # text shows up immediately on all of them.
        for stream in self.files:
            stream.write(obj)
            stream.flush()

    def flush(self):
        # Propagate explicit flush requests to every destination as well.
        for stream in self.files:
            stream.flush()
55
+
56
+
57
def run_qa_experiment(
    experiment_name: str,
    model: QAModel,
    debug_limit: Optional[int] = None,
    val_fraction: float = DEFAULT_VAL_SET_FRACTION,
) -> Tuple[ExperimentSnapshot, Path, Optional[pd.DataFrame], Optional[pd.DataFrame]]:
    """
    Basic pipeline for running a QA system experiment:
    load data -> train model -> predict on dev -> evaluate -> persist snapshot.

    Args:
        experiment_name: label used for the tracked run directory.
        model: any QAModel implementation; its `config` attribute is snapshotted.
        debug_limit: if given, caps the number of training ARTICLES loaded.
        val_fraction: fraction of training articles carved out as validation.

    Returns:
        (snapshot, run_dir, df_train, df_dev) — the saved ExperimentSnapshot,
        its on-disk directory, and the ETLed train/dev DataFrames (returned to
        facilitate debugging).

    To facilitate debugging:
    1. The function can limit the #training examples processed
    2. The sampled ETLed input DF is also provided as part of the function return

    Note that debug_limit is only applied to the training instances; i.e., dev set is not capped.
    """
    # TODO - use proper logging for all of this
    # Capture output to StringIO while printing to console; the captured text
    # is later written into the run directory as training_log.txt.
    log_capture = StringIO()
    original_stdout = sys.stdout
    sys.stdout = Tee(sys.stdout, log_capture)

    try:
        if debug_limit is not None:
            print(f"{debug_limit} articles will be considered from training in total.")
        else:
            print("All articles from training set are considered.")
        assert TRAIN_DATA_PATH.exists(), "Unspecified train data location."
        # Note that df_val can be returned for debugging: ignored for now
        (train_examples, val_examples), (df_train, _) = _load_examples(
            path=TRAIN_DATA_PATH, debug_limit=debug_limit, split_fraction=val_fraction
        )

        assert DEV_DATA_PATH.exists(), "Unspecified dev data location."
        # do NOT split dev set -> split_fraction is explicitly set to None
        (dev_examples, _), (df_dev, _) = _load_examples(
            path=DEV_DATA_PATH, debug_limit=None, split_fraction=None
        )

        # Sanity checking for non-empty data splits
        assert len(train_examples) > 0, "train_examples is empty."
        assert len(dev_examples) > 0, "dev_examples is empty."

        # Pass val_examples only when a split was produced, so models that
        # ignore validation data still work unchanged.
        if val_examples is not None:
            model.train(train_examples, val_examples=val_examples)
        else:
            model.train(train_examples)
        predictions = model.predict(dev_examples)
        metrics = Evaluator().evaluate(predictions=predictions, examples=dev_examples)

        # Save experiment
        config = getattr(model, "config", None)
        assert isinstance(config, BaseModelConfig), "Incompatible Config type."
        snapshot = ExperimentSnapshot(
            experiment_name=experiment_name,
            config=config,
            predictions=predictions,
            metrics=metrics,
            model=model,
        )

        print("\n" + "=" * 70)
        print("FINAL DEV SET RESULTS")
        print("=" * 70)
        print(f"Exact Match (EM): {snapshot.metrics.exact_score:.2f}%")
        print(f"F1 Score: {snapshot.metrics.f1_score:.2f}%")
        print(f"Total dev examples: {snapshot.metrics.total_num_instances}")
        print("=" * 70)

        # Persist artifacts first, then the captured console log next to them.
        run_dir = snapshot.save(experiments_root=EXPERIMENTS_DIR)
        (run_dir / "training_log.txt").write_text(
            log_capture.getvalue(), encoding="utf-8"
        )
        return snapshot, run_dir, df_train, df_dev
    finally:
        # Restore stdout after running the experiment
        sys.stdout = original_stdout
133
+
134
+
135
def create_experiment_name(
    model_name_short: str, config: BertQAConfig, num_articles: Optional[int] = None
) -> str:
    """
    Derive a tracking-friendly experiment name such as "bert_5_articles", or
    "<name>_ALL_articles" when no article cap is applied, after sanity-checking
    that the short model name is consistent with the configured backbone.
    """
    assert (
        model_name_short in config.backbone_name
    ), "Inconsistent model name used for experiment tracking Vs actual model name."
    if num_articles is None:
        return f"{model_name_short}_ALL_articles"
    return f"{model_name_short}_{num_articles}_articles"
147
+
148
+
149
def _load_examples(
    path: Path, debug_limit: int | None, split_fraction: float | None
) -> Tuple[
    Tuple[Dict[str, QAExample], Dict[str, QAExample] | None],
    Tuple[pd.DataFrame, pd.DataFrame | None],
]:
    """
    Load SQuAD v2.0 examples from `path`, returning both the QAExample maps and
    the underlying DataFrames (the latter mainly for debugging).

    Both debug_limit and split_fraction operate at the ARTICLE level rather
    than the individual example/question level:
    - debug_limit: caps the #articles returned for debugging/easier experimentation
    - split_fraction: enables train/val splitting based on initial training data
    """
    df = load_squad_v2_df(path)

    if debug_limit is not None:
        titles = df[Col.TITLE.value].unique()
        assert (
            1 <= debug_limit <= len(titles)
        ), f"debug_limit={debug_limit} exceeds {len(titles)} available articles"

        # Sample whole articles (titles) rather than individual rows, so every
        # kept article retains all of its questions.
        kept_titles = pd.Series(titles).sample(n=debug_limit, random_state=DEBUG_SEED)
        df = df[df[Col.TITLE.value].isin(kept_titles)].copy()

    if split_fraction is None:
        examples = df_to_examples_map(df)
        return (examples, None), (df, None)

    df_train, df_val = split_by_title(df, split_fraction)
    train_examples = df_to_examples_map(df_train)
    val_examples = df_to_examples_map(df_val)
    return (train_examples, val_examples), (df_train, df_val)
184
+
185
+
186
def split_by_title(
    df: pd.DataFrame, val_fraction: float
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Split input DF by article title ensuring no title overlap:
    this is critical for model generalization to new contexts Vs
    simply memorizing passages and responding to new questions about them
    (e.g., when splitting the initial training set into 'train' and 'val' subsets).

    Returns (df_train, df_val); both are independent copies.
    """
    assert 0 < val_fraction < 1, "val set fraction should be between (0, 1)."
    unique_titles = df[Col.TITLE.value].drop_duplicates()
    num_unique_titles = len(unique_titles)
    # With fewer than two articles, the max(1, ...) floor below would put the
    # only article into val and silently return an EMPTY train split, which
    # would otherwise surface as a confusing failure much later downstream.
    assert (
        num_unique_titles >= 2
    ), f"Need at least 2 unique titles to split; got {num_unique_titles}."
    # Fixed seed keeps the split reproducible across runs.
    shuffled_titles = unique_titles.sample(frac=1.0, random_state=DEBUG_SEED)

    # Guarantee at least one val article even for tiny inputs/fractions.
    n_val = max(1, int(num_unique_titles * val_fraction))
    val_titles = set(shuffled_titles.iloc[:n_val])
    train_titles = set(shuffled_titles.iloc[n_val:])

    df_val = df[df[Col.TITLE.value].isin(val_titles)].copy()
    df_train = df[df[Col.TITLE.value].isin(train_titles)].copy()
    print(
        f"Initial split | num-train-examples: {df_train.shape[0]}; num-val-examples: {df_val.shape[0]}"
    )
    return df_train, df_val
src/scripts/prepare_hf_deployment.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ from pathlib import Path
3
+ import sys
4
+
5
+ # Experiment directory to be used for deployment is passed in as input argument
6
+ exp_dir = Path(sys.argv[1])
7
+ deploy_dir = Path("hf_deployment")
8
+
9
+ # Safety-first: enables first seeing the changes before actually transfering files over (AWS s3 operations-like)
10
+ dry_run = "--dry-run" in sys.argv
11
+ checkpoint = deploy_dir / "checkpoint"
12
+ prefix = "[DRY RUN]" if dry_run else ""
13
+
14
+ print(f"{prefix} Create: {checkpoint}")
15
+ if not dry_run:
16
+ checkpoint.mkdir(parents=True, exist_ok=True)
17
+
18
+ # Individual files
19
+ files = [
20
+ (exp_dir / "config.json", checkpoint / "config.json"),
21
+ (exp_dir / "model/pytorch_model.bin", checkpoint / "pytorch_model.bin"),
22
+ ]
23
+ for src, dst in files:
24
+ print(f"{prefix} Copy: {src} -> {dst}")
25
+ if not dry_run:
26
+ shutil.copy2(src, dst)
27
+
28
+ # Directories (recursively)
29
+ trees = [
30
+ (exp_dir / "model/tokenizer", checkpoint / "tokenizer"),
31
+ (Path("src"), deploy_dir / "src"),
32
+ ]
33
+ for src, dst in trees:
34
+ print(f"{prefix} Copy tree: {src} -> {dst}")
35
+ if not dry_run:
36
+ shutil.copytree(src, dst, dirs_exist_ok=True)
37
+
38
+ if not dry_run:
39
+ print(f"\nDeployment files are ready under {deploy_dir}.")
src/utils/__init__.py ADDED
File without changes
src/utils/constants.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Column name constants for the SQuAD v2.0 DataFrame & raw input field names.
3
+
4
+ Benefits:
5
+ - Single source of truth: schema changes are centralized
6
+ - Safety: typos are caught at definition time rather than scattered string literals
7
+ - IDE support: `Col.` autocompletes all valid names, streamlining typing and making schemas self-documenting
8
+ """
9
+
10
+ from enum import Enum
11
+ from pathlib import Path
12
+
13
# constants.py lives at: <repo>/src/utils/constants.py;
# resolve() addresses symlink issues
REPO_ROOT: Path = Path(__file__).resolve().parent.parent.parent
# SQuAD v2.0 JSON files are expected under <repo>/data/.
DATA_DIR: Path = REPO_ROOT / "data"
# TODO - Placeholder needs to be made smaller for experiments!
TRAIN_DATA_PATH: Path = DATA_DIR / "train-v2.0.json"
DEV_DATA_PATH: Path = DATA_DIR / "dev-v2.0.json"
# Each experiment run is persisted into its own timestamped subdirectory here.
EXPERIMENTS_DIR: Path = REPO_ROOT / "experiments"

# Fixed seed for reproducible sampling/splitting in debug runs.
DEBUG_SEED = 42
23
+
24
+
25
class Col(Enum):
    """Column names of the flattened SQuAD v2.0 DataFrame."""

    # Schema entries below are reused for raw keys with identical names
    TITLE = "title"  # article title; used for article-level sampling/splitting
    QUESTION_ID = "id"  # unique question identifier
    QUESTION = "question"
    CONTEXT = "context"  # passage the answer is extracted from
    ANSWER_TEXTS = "answers"  # gold answer strings
    ANSWER_STARTS = "answer_starts"  # per-answer start positions (cf. RawField.ANSWER_START)
    IS_IMPOSSIBLE = "is_impossible"  # SQuAD v2.0 unanswerable-question flag
    NUM_ANSWERS = "num_answers"  # convenience count of gold answers
35
+
36
+
37
class RawField(Enum):
    """Key names as they appear in the raw SQuAD v2.0 JSON files."""

    VERSION = "version"  # dataset version string
    DATA = "data"  # top-level list of articles
    PARAGRAPHS = "paragraphs"  # per-article list of context paragraphs
    QAS = "qas"  # per-paragraph list of question/answer entries
    # QA-level answers (list of dicts with 'text' and 'answer_start')
    ANSWERS = "answers"
    ANSWER_TEXT = "text"
    ANSWER_START = "answer_start"
src/utils/experiment_snapshot.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A 'snapshot' object used for experiment tracking.
3
+ Contains an experiment_name, the config used, the predictions produced,
4
+ the resulting metrics, the model parameters and optional metadata.
5
+
6
+ Benefits:
7
+ - Single, self-contained function call to persist an experiment run.
8
+ - Clean and automatic organization of experimental results facilitating model improvements.
9
+ """
10
+
11
+ import time, json, torch
12
+ from pathlib import Path
13
+ from dataclasses import dataclass, asdict, is_dataclass
14
+ from typing import Dict, Any, Optional
15
+ from src.config.model_configs import BaseModelConfig
16
+ from src.etl.types import Prediction
17
+ from src.evaluation.metrics import Metrics
18
+ from src.models.bert_based_model import BertBasedQAModel
19
+ from src.models.base_qa_model import QAModel
20
+
21
# Text encoding used for all persisted artifact files in a run directory.
DEFAULT_ENCODING = "utf-8"
22
+
23
+
24
@dataclass(frozen=True)
class ExperimentSnapshot:
    """
    Immutable record of one experiment run: name, config, predictions, metrics,
    optional metadata, and (optionally) the trained model itself.
    `save()` persists everything into a fresh timestamped run directory.
    """

    experiment_name: str
    config: BaseModelConfig
    predictions: Dict[str, Prediction]
    metrics: Metrics
    metadata: Optional[Dict[str, Any]] = None
    model: Optional[QAModel] = None  # stores model reference

    def _timestamped_dir(self, root: Path) -> Path:
        # e.g. <root>/20250101_120000_my_experiment; the timestamp prefix keeps
        # runs chronologically sorted and avoids directory-name collisions.
        ts = time.strftime("%Y%m%d_%H%M%S")
        return root / f"{ts}_{self.experiment_name}"

    def _as_config_dict(self) -> Dict[str, Any]:
        # Dataclass configs are serialized field-by-field; anything else is
        # assumed to be dict-convertible.
        return asdict(self.config) if is_dataclass(self.config) else dict(self.config)

    def _manifest(self, run_id: str) -> Dict[str, Any]:
        # Small index describing where each artifact lives inside the run dir.
        model_type = getattr(self.config, "MODEL_TYPE", None)
        assert model_type is not None, "Unexpected empty model type."
        mani = {
            "run_id": run_id,
            "experiment_name": self.experiment_name,
            "model_type": model_type,
            "artifacts": {
                "config": "config.json",
                "predictions": "predictions.json",
                "metrics": "metrics.json",
                "model": "model/",
            },
        }

        # TODO - consider adding path to model checkpoints once we have those
        if self.metadata:
            mani["metadata"] = self.metadata  # pass-through, unchanged
        return mani

    def save(self, experiments_root: Path = Path("experiments")) -> Path:
        """
        Persist config, predictions, metrics, manifest and (if present) the
        model into a new timestamped directory under `experiments_root`.
        Returns the created run directory. Raises FileExistsError if the
        directory already exists (never silently overwrites a previous run).
        """
        run_dir = self._timestamped_dir(experiments_root)
        # raise error if accidentally attempting to overwrite previous run
        run_dir.mkdir(parents=True, exist_ok=False)

        (run_dir / "config.json").write_text(
            json.dumps(self._as_config_dict(), indent=2), encoding=DEFAULT_ENCODING
        )
        (run_dir / "predictions.json").write_text(
            json.dumps(
                Prediction.flatten_predicted_answers(predictions=self.predictions),
                ensure_ascii=False,  # preserve original characters (e.g., accented characters etc.)
                indent=2,
            ),
            encoding=DEFAULT_ENCODING,
        )
        (run_dir / "metrics.json").write_text(
            json.dumps(self.metrics.export_for_exp_tracking(), indent=2),
            encoding=DEFAULT_ENCODING,
        )

        if self.model is not None:
            self._save_model(run_dir / "model")

        # Manifest is written last so its presence signals a complete run.
        manifest = self._manifest(run_dir.name)
        (run_dir / "manifest.json").write_text(
            json.dumps(manifest, indent=2), encoding=DEFAULT_ENCODING
        )
        return run_dir

    def _save_model(self, model_path: Path) -> None:
        """Save model weights and tokenizer under `model_path`."""
        assert isinstance(
            self.model, BertBasedQAModel
        ), "Currently model saving is only supported for the BertBasedQAModel type."
        model_path.mkdir(parents=True, exist_ok=True)

        # Save model weights
        torch.save(self.model.qa_module.state_dict(), model_path / "pytorch_model.bin")
        # Save tokenizer
        self.model.tokenizer.save_pretrained(model_path / "tokenizer")
        print(f"Model saved to {model_path}")
src/utils/tune_threshold.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Minimal threshold tuning script that reuses existing pipeline code.
3
+ Can be run from a Jupyter notebook.
4
+
5
+ Note that tuning does NOT happen on the 'dev' set, which is considered to be
6
+ an external, unseen dataset for objective performance measurement. Threshold
7
+ tuning happens on the 'val' set (which is a small part of the SQuAD v2.0 training).
8
+ """
9
+
10
+ import numpy as np
11
+ from pathlib import Path
12
+ from src.models.bert_based_model import BertBasedQAModel
13
+ from src.etl.squad_v2_loader import load_squad_v2_df, df_to_examples_map
14
+ from src.evaluation.evaluator import Evaluator
15
+ from src.utils.constants import TRAIN_DATA_PATH, DEV_DATA_PATH
16
+ from src.pipeline.qa_runner import split_by_title, DEFAULT_VAL_SET_FRACTION
17
+
18
+
19
def tune_threshold_and_report_final_perf(
    experiment_dir: str | Path,
    config_class,
    device: str,
    threshold_range: np.ndarray | None = None,
):
    """
    Tune the no-answer threshold on the validation split (carved out of the
    training data) and report final performance on the untouched dev set.

    experiment_dir: run directory to load the trained model from
    config_class: config type used to reconstruct the model
    device: device string used for inference
    threshold_range: candidate thresholds; defaults to np.linspace(-2, 2, 9)

    Returns (best_threshold, model).
    """
    # Avoid a mutable (ndarray) default argument: the previous module-level
    # default was created once at import time and shared across all calls,
    # so any in-place mutation would leak between callers. A None sentinel
    # builds a fresh array per call while keeping the same default behavior.
    if threshold_range is None:
        threshold_range = np.linspace(-2, 2, 9)

    experiment_dir = Path(experiment_dir)
    best_threshold, _, _, model = _tune_threshold_on_validation(
        experiment_dir=experiment_dir,
        config_class=config_class,
        device=device,
        threshold_range=threshold_range,
    )

    # Final, objective measurement on the held-out dev set using the tuned threshold.
    dev_examples = df_to_examples_map(load_squad_v2_df(DEV_DATA_PATH))
    final_predictions = model.predict(dev_examples, threshold_override=best_threshold)
    final_metrics = Evaluator().evaluate(final_predictions, dev_examples)
    print(f"Final dev set performance: {final_metrics.export_for_exp_tracking()}")

    return best_threshold, model
43
+
44
+
45
def _tune_threshold_on_validation(
    experiment_dir: Path,
    config_class,
    device: str,
    threshold_range: np.ndarray,
    val_fraction: float = DEFAULT_VAL_SET_FRACTION,
):
    """
    Grid-search the no-answer threshold on the validation split.

    Re-creates the title-based train/val split (split_by_title uses a fixed
    seed, so the same val titles are reproduced — see TODO below), evaluates
    every candidate threshold, and returns
    (best_threshold, best_metrics, results, model), where `results` is a list
    of {"threshold", "em", "f1"} dicts for every candidate tried and selection
    is by highest F1.
    """
    print("=" * 70)
    print("THRESHOLD TUNING ON VALIDATION SET")
    print("=" * 70)

    model = BertBasedQAModel.load_from_experiment(
        experiment_dir, config_class, device=device
    )
    # TODO - can also store/load the exact val question IDs used during training,
    # to be even more certain that we are tuning on the exact val set
    df = load_squad_v2_df(TRAIN_DATA_PATH)
    _, df_val = split_by_title(df, val_fraction)
    val_examples = df_to_examples_map(df_val)

    print(f"\nValidation set: {len(val_examples)} examples")
    print(
        f"Testing {len(threshold_range)} thresholds from {threshold_range.min():.1f} to {threshold_range.max():.1f}\n"
    )

    # Test each threshold
    best_f1 = -1
    best_threshold = None
    best_metrics = None
    results = []

    for threshold in threshold_range:
        predictions = model.predict(val_examples, threshold_override=threshold)
        metrics = Evaluator().evaluate(predictions, val_examples)

        results.append(
            {"threshold": threshold, "em": metrics.exact_score, "f1": metrics.f1_score}
        )

        print(
            f"Threshold: {threshold:6.2f} | EM: {metrics.exact_score:5.2f}% | F1: {metrics.f1_score:5.2f}%"
        )

        # Track the best candidate by F1 (strictly greater keeps the earliest tie).
        if metrics.f1_score > best_f1:
            best_f1 = metrics.f1_score
            best_threshold = threshold
            best_metrics = metrics

    # Type assertion
    assert best_metrics is not None, "No thresholds tested!"

    print("\n" + "=" * 70)
    print("BEST THRESHOLD")
    print("=" * 70)
    print(f"Threshold: {best_threshold:.2f}")
    print(f"EM: {best_metrics.exact_score:.2f}%")
    print(f"F1: {best_metrics.f1_score:.2f}%")
    print("=" * 70)

    return best_threshold, best_metrics, results, model