"""
Contains supplementary routines for post-hoc validation and inspection of results:
- An additional safeguard that dev-set results are reliable (external recomputation
  of the F1/EM metrics).
- Example-level inspection for users.
"""

import json
from pathlib import Path

import pandas as pd

from src.utils.constants import Col
from src.evaluation.squad_v2_official import normalize_answer, compute_exact, compute_f1


def validate_experiment(exp_dir: Path, df: pd.DataFrame, *, tol: float = 0.01) -> pd.DataFrame:
    """Load predictions, recompute EM/F1 scores, and compare them with the saved metrics.

    Args:
        exp_dir: Experiment directory containing ``predictions.json`` (loaded as a
            mapping from question id to predicted answer) and ``metrics.json``
            (read for its ``exact_score`` and ``f1_score`` entries).
        df: Dev-set frame with ``Col.QUESTION_ID`` and ``Col.ANSWER_TEXTS`` columns.
        tol: Maximum absolute difference (in percentage points) between recomputed
            and saved metrics for the result to be reported as a match.

    Returns:
        ``df`` indexed by question id, with ``predicted_answer``, ``em_score``
        and ``f1_score`` columns added.

    Raises:
        ValueError: If any question in ``df`` has no prediction.
    """
    exp_dir = Path(exp_dir)

    # Load and merge predictions. JSON files are UTF-8; be explicit so the
    # platform's locale encoding is never used to decode non-ASCII answers.
    preds = json.loads((exp_dir / "predictions.json").read_text(encoding="utf-8"))
    pred_series = pd.Series(preds, name="predicted_answer")
    # Left join: predictions for ids not present in `df` are ignored.
    df_eval = df.set_index(Col.QUESTION_ID.value).join(pred_series)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    missing = df_eval.index[df_eval["predicted_answer"].isna()]
    if len(missing) > 0:
        raise ValueError(
            f"Missing predictions for {len(missing)} question(s), e.g. {list(missing[:5])}"
        )

    df_eval = _compute_scores(df_eval)
    computed_em = 100.0 * df_eval["em_score"].mean()
    computed_f1 = 100.0 * df_eval["f1_score"].mean()

    # Compare with the metrics saved at evaluation time.
    saved = json.loads((exp_dir / "metrics.json").read_text(encoding="utf-8"))
    saved_em, saved_f1 = saved["exact_score"], saved["f1_score"]

    print(f"\n{exp_dir.name}")
    print(f"Computed: EM={computed_em:.2f}%, F1={computed_f1:.2f}%")
    print(f"Saved: EM={saved_em:.2f}%, F1={saved_f1:.2f}%")

    em_ok = abs(computed_em - saved_em) < tol
    f1_ok = abs(computed_f1 - saved_f1) < tol
    print("MATCH\n" if em_ok and f1_ok else "MISMATCH - check evaluation\n")

    return df_eval


def _gold_answers(raw_golds) -> list:
    """Turn one cell of gold answer texts into the non-empty list used for scoring.

    Gold answers that normalize to the empty string are dropped; if nothing
    remains (or the cell is empty/None), the question is scored against [""].
    """
    # Materialize first: `if not golds` would raise on array-like cells
    # (e.g. if answer texts are stored as numpy arrays).
    golds = [] if raw_golds is None else list(raw_golds)
    return [g for g in golds if normalize_answer(str(g))] or [""]


def _compute_scores(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with per-question ``em_score`` and ``f1_score`` columns.

    Each question's score is the maximum over its (filtered) gold answers,
    computed with the imported ``compute_exact`` / ``compute_f1``.
    """
    em_scores = []
    f1_scores = []
    for raw_golds, pred in zip(df[Col.ANSWER_TEXTS.value], df["predicted_answer"]):
        golds = _gold_answers(raw_golds)  # never empty, so max() needs no default
        em_scores.append(max(compute_exact(g, pred) for g in golds))
        f1_scores.append(max(compute_f1(g, pred) for g in golds))

    out = df.copy()
    out["em_score"] = em_scores
    out["f1_score"] = f1_scores
    return out