File size: 3,572 Bytes
461f64f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
Minimal threshold tuning script that reuses existing pipeline code.
Can be run from a Jupyter notebook.

Note that tuning does NOT happen on the 'dev' set, which is considered to be
an external, unseen dataset for objective performance measurement. Threshold
tuning happens on the 'val' set (which is a small part of the SQuAD v2.0 training set).
"""

import numpy as np
from pathlib import Path
from src.models.bert_based_model import BertBasedQAModel
from src.etl.squad_v2_loader import load_squad_v2_df, df_to_examples_map
from src.evaluation.evaluator import Evaluator
from src.utils.constants import TRAIN_DATA_PATH, DEV_DATA_PATH
from src.pipeline.qa_runner import split_by_title, DEFAULT_VAL_SET_FRACTION


def tune_threshold_and_report_final_perf(
    experiment_dir: str | Path,
    config_class,
    device: str,
    threshold_range: np.ndarray = np.linspace(-2, 2, 9),
):
    """
    Pick the best threshold on the validation split, then score the dev set.

    The threshold search is delegated to ``_tune_threshold_on_validation``.
    The model it returns is then run once on the dev data with the selected
    threshold, and the resulting metrics are printed.

    Args:
        experiment_dir: Experiment directory the model is loaded from
            (converted to ``Path`` before use).
        config_class: Config class forwarded to the model loader.
        device: Device identifier forwarded to the model loader.
        threshold_range: Candidate thresholds to try on the validation split.

    Returns:
        A ``(best_threshold, model)`` tuple.
    """
    tuning_outcome = _tune_threshold_on_validation(
        experiment_dir=Path(experiment_dir),
        config_class=config_class,
        device=device,
        threshold_range=threshold_range,
    )
    # Only the selected threshold and the loaded model are needed here; the
    # validation metrics and the per-threshold results are discarded.
    chosen_threshold, _val_metrics, _sweep_results, qa_model = tuning_outcome

    # The dev set is loaded only after tuning has finished.
    dev_df = load_squad_v2_df(DEV_DATA_PATH)
    dev_examples = df_to_examples_map(dev_df)
    dev_predictions = qa_model.predict(
        dev_examples, threshold_override=chosen_threshold
    )
    dev_metrics = Evaluator().evaluate(dev_predictions, dev_examples)
    print(f"Final dev set performance: {dev_metrics.export_for_exp_tracking()}")

    return chosen_threshold, qa_model


def _tune_threshold_on_validation(
    experiment_dir: Path,
    config_class,
    device: str,
    threshold_range: np.ndarray,
    val_fraction: float = DEFAULT_VAL_SET_FRACTION,
):
    """
    Sweep thresholds on the validation split and keep the one with the best F1.

    The model is loaded from ``experiment_dir``, the training data is split
    with ``split_by_title`` and the second part of that split is used as the
    validation set. Every candidate threshold is evaluated on it and a
    progress line is printed per threshold, followed by a summary.

    Args:
        experiment_dir: Experiment directory the model is loaded from.
        config_class: Config class forwarded to the model loader.
        device: Device identifier forwarded to the model loader.
        threshold_range: Candidate thresholds; any array-like (or a single
            scalar) is accepted.
        val_fraction: Fraction handed to ``split_by_title``.

    Returns:
        A ``(best_threshold, best_metrics, results, model)`` tuple, where
        ``results`` is a list of ``{"threshold", "em", "f1"}`` dicts, one per
        tested threshold, in the order tested.

    Raises:
        ValueError: If ``threshold_range`` is empty.
        RuntimeError: If no tested threshold produced an F1 score that
            compares greater than the initial sentinel (e.g. all NaN).
    """
    # Validate up front so bad input fails with a clear message *before* the
    # model and dataset are loaded, instead of a cryptic numpy error from
    # ``.min()`` on an empty array (or an AttributeError on a plain list).
    threshold_range = np.atleast_1d(threshold_range)
    if threshold_range.size == 0:
        raise ValueError("threshold_range must contain at least one threshold")

    print("=" * 70)
    print("THRESHOLD TUNING ON VALIDATION SET")
    print("=" * 70)

    model = BertBasedQAModel.load_from_experiment(
        experiment_dir, config_class, device=device
    )
    # TODO - can also store/load the exact val question IDs used during training,
    # to be even more certain that we are tuning on the exact val set
    df = load_squad_v2_df(TRAIN_DATA_PATH)
    _, df_val = split_by_title(df, val_fraction)
    val_examples = df_to_examples_map(df_val)

    print(f"\nValidation set: {len(val_examples)} examples")
    print(
        f"Testing {len(threshold_range)} thresholds from {threshold_range.min():.1f} to {threshold_range.max():.1f}\n"
    )

    # Test each threshold
    best_f1 = -1
    best_threshold = None
    best_metrics = None
    results = []

    for threshold in threshold_range:
        predictions = model.predict(val_examples, threshold_override=threshold)
        metrics = Evaluator().evaluate(predictions, val_examples)

        results.append(
            {"threshold": threshold, "em": metrics.exact_score, "f1": metrics.f1_score}
        )

        print(
            f"Threshold: {threshold:6.2f} | EM: {metrics.exact_score:5.2f}% | F1: {metrics.f1_score:5.2f}%"
        )

        # Strict '>' means the first threshold tested wins on ties.
        if metrics.f1_score > best_f1:
            best_f1 = metrics.f1_score
            best_threshold = threshold
            best_metrics = metrics

    # Explicit check rather than ``assert``: asserts are removed under
    # ``python -O``, which would let ``None`` reach the formatting below.
    if best_metrics is None:
        raise RuntimeError(
            "No tested threshold produced a valid F1 score (are the scores NaN?)"
        )

    print("\n" + "=" * 70)
    print("BEST THRESHOLD")
    print("=" * 70)
    print(f"Threshold: {best_threshold:.2f}")
    print(f"EM: {best_metrics.exact_score:.2f}%")
    print(f"F1: {best_metrics.f1_score:.2f}%")
    print("=" * 70)

    return best_threshold, best_metrics, results, model