|
|
""" |
|
|
Minimal threshold tuning script that reuses existing pipeline code. |
|
|
Can be run from a Jupyter notebook. |
|
|
|
|
|
Note that tuning does NOT happen on the 'dev' set, which is considered to be |
|
|
an external, unseen dataset for objective performance measurement. Threshold |
|
|
tuning happens on the 'val' set (which is a small part of the SQuAD v2.0 training set).
|
|
""" |
|
|
|
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
from src.models.bert_based_model import BertBasedQAModel |
|
|
from src.etl.squad_v2_loader import load_squad_v2_df, df_to_examples_map |
|
|
from src.evaluation.evaluator import Evaluator |
|
|
from src.utils.constants import TRAIN_DATA_PATH, DEV_DATA_PATH |
|
|
from src.pipeline.qa_runner import split_by_title, DEFAULT_VAL_SET_FRACTION |
|
|
|
|
|
|
|
|
def tune_threshold_and_report_final_perf(
    experiment_dir: str | Path,
    config_class,
    device: str,
    threshold_range: np.ndarray | None = None,
):
    """
    Tune the no-answer threshold on the validation split, then report final
    performance on the dev set.

    Tuning only ever sees the 'val' split (carved out of the SQuAD v2.0
    training set); the dev set is used exactly once, for the final
    measurement.

    Args:
        experiment_dir: Directory containing the trained model artifacts.
        config_class: Config class used to restore the model.
        device: Device string passed through to model loading (e.g. "cpu").
        threshold_range: Candidate thresholds to sweep. Defaults to 9 evenly
            spaced values in [-2, 2]. A ``None`` sentinel is used instead of
            a module-level default array so the default is not a single
            mutable object shared across all calls.

    Returns:
        Tuple of (best_threshold, model).
    """
    # Resolve the default here rather than in the signature: a default
    # np.linspace(...) would be evaluated once at import time and shared.
    if threshold_range is None:
        threshold_range = np.linspace(-2, 2, 9)

    experiment_dir = Path(experiment_dir)
    best_threshold, _, _, model = _tune_threshold_on_validation(
        experiment_dir=experiment_dir,
        config_class=config_class,
        device=device,
        threshold_range=threshold_range,
    )

    # Final, one-shot measurement on the unseen dev set with the tuned threshold.
    dev_examples = df_to_examples_map(load_squad_v2_df(DEV_DATA_PATH))
    final_predictions = model.predict(dev_examples, threshold_override=best_threshold)
    final_metrics = Evaluator().evaluate(final_predictions, dev_examples)
    print(f"Final dev set performance: {final_metrics.export_for_exp_tracking()}")

    return best_threshold, model
|
|
|
|
|
|
|
|
def _tune_threshold_on_validation(
    experiment_dir: Path,
    config_class,
    device: str,
    threshold_range: np.ndarray,
    val_fraction: float = DEFAULT_VAL_SET_FRACTION,
):
    """
    Sweep candidate no-answer thresholds on the validation split and pick the
    one with the highest F1.

    Args:
        experiment_dir: Directory containing the trained model artifacts.
        config_class: Config class used to restore the model.
        device: Device string passed through to model loading.
        threshold_range: Candidate thresholds to evaluate.
        val_fraction: Fraction of the training data (split by title) held out
            as the validation set.

    Returns:
        Tuple of (best_threshold, best_metrics, results, model), where
        ``results`` is a list of ``{"threshold", "em", "f1"}`` dicts covering
        every threshold tried.
    """
    print("=" * 70)
    print("THRESHOLD TUNING ON VALIDATION SET")
    print("=" * 70)

    model = BertBasedQAModel.load_from_experiment(
        experiment_dir, config_class, device=device
    )

    # The validation split is carved out of the TRAINING data (split by
    # title), so tuning never touches the dev set.
    df = load_squad_v2_df(TRAIN_DATA_PATH)
    _, df_val = split_by_title(df, val_fraction)
    val_examples = df_to_examples_map(df_val)

    print(f"\nValidation set: {len(val_examples)} examples")
    print(
        f"Testing {len(threshold_range)} thresholds from {threshold_range.min():.1f} to {threshold_range.max():.1f}\n"
    )

    # Hoisted out of the loop: the evaluator is stateless across calls here,
    # so there is no need to re-instantiate it per threshold.
    evaluator = Evaluator()

    best_f1 = -1
    best_threshold = None
    best_metrics = None
    results = []

    for threshold in threshold_range:
        predictions = model.predict(val_examples, threshold_override=threshold)
        metrics = evaluator.evaluate(predictions, val_examples)

        results.append(
            {"threshold": threshold, "em": metrics.exact_score, "f1": metrics.f1_score}
        )

        print(
            f"Threshold: {threshold:6.2f} | EM: {metrics.exact_score:5.2f}% | F1: {metrics.f1_score:5.2f}%"
        )

        # Strict '>' keeps the earliest (lowest) threshold on F1 ties.
        if metrics.f1_score > best_f1:
            best_f1 = metrics.f1_score
            best_threshold = threshold
            best_metrics = metrics

    assert best_metrics is not None, "No thresholds tested!"

    print("\n" + "=" * 70)
    print("BEST THRESHOLD")
    print("=" * 70)
    print(f"Threshold: {best_threshold:.2f}")
    print(f"EM: {best_metrics.exact_score:.2f}%")
    print(f"F1: {best_metrics.f1_score:.2f}%")
    print("=" * 70)

    return best_threshold, best_metrics, results, model
|
|
|