"""
Runs evaluation (RQ1–RQ4, statistical tests, plots) on previously annotated
pipeline outputs that include `human_correct` and `human_faithful`.
Assumes outputs were generated using `separate_for_annotation.py` and
subsequently annotated.
"""
| import argparse | |
| import json | |
| import logging | |
| import itertools | |
| from pathlib import Path | |
| import numpy as np | |
| import yaml | |
| import matplotlib.pyplot as plt | |
| from evaluation.stats import ( | |
| corr_ci, | |
| wilcoxon_signed_rank, | |
| holm_bonferroni, | |
| conditional_failure_rate, | |
| chi2_error_propagation, | |
| delta_metric, | |
| ) | |
| from evaluation.utils.logger import init_logging | |
def read_jsonl(path: Path) -> list[dict]:
    """Load a JSON Lines file into a list of parsed records.

    Skips blank lines (e.g. a trailing newline at EOF), which would otherwise
    raise ``json.JSONDecodeError``. Reads as UTF-8 explicitly so behavior does
    not depend on the platform's default encoding.
    """
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
def save_yaml(path: Path, obj: dict) -> None:
    """Serialize *obj* as YAML to *path*, creating parent directories as needed.

    Uses ``yaml.safe_dump`` with ``sort_keys=False`` so insertion order of the
    dict is preserved in the output file.
    """
    parent = path.parent
    parent.mkdir(parents=True, exist_ok=True)
    text = yaml.safe_dump(obj, sort_keys=False)
    path.write_text(text)
def agg_mean(rows: list[dict]) -> dict:
    """Average each metric across all rows.

    Metric keys are taken from the first row (all rows are assumed to share the
    same keys — TODO confirm against the pipeline's output schema).

    Returns an empty dict for empty input instead of raising ``IndexError``
    (fix: the original indexed ``rows[0]`` unconditionally).
    """
    if not rows:
        return {}
    keys = rows[0]["metrics"].keys()
    return {k: float(np.mean([r["metrics"][k] for r in rows])) for k in keys}
def rq1_correlation(rows):
    """RQ1: bootstrap Pearson correlation of retrieval metrics vs. human correctness.

    Returns ``{metric: {"r": ..., "ci": [lo, hi], "p": ...}}`` for each of the
    recognized retrieval metrics present in the rows, or ``{}`` when no human
    correctness labels are available.
    """
    # No annotation present (key missing, or explicitly None) -> nothing to do.
    if rows[0].get("human_correct") is None:
        return {}
    wanted = {"mrr", "map", "precision@10"}
    gold = [float(bool(r["human_correct"])) for r in rows]
    out = {}
    for key in rows[0]["metrics"]:
        if key not in wanted:
            continue
        scores = [r["metrics"][key] for r in rows]
        r_val, (lo, hi), p = corr_ci(scores, gold, method="pearson", n_boot=1000, ci=0.95)
        out[key] = {"r": r_val, "ci": [lo, hi], "p": p}
    return out
def rq2_faithfulness(rows):
    """RQ2: bootstrap Pearson correlation of faithfulness metrics vs. human labels.

    A metric counts as a faithfulness metric when its lowercased name starts
    with one of: faith, qags, fact, ragas. Returns ``{}`` when no human
    faithfulness labels are available.
    """
    # No annotation present (key missing, or explicitly None) -> nothing to do.
    if rows[0].get("human_faithful") is None:
        return {}
    prefixes = ("faith", "qags", "fact", "ragas")
    gold = [r["human_faithful"] for r in rows]
    out = {}
    for key in rows[0]["metrics"]:
        if not key.lower().startswith(prefixes):
            continue
        scores = [r["metrics"][key] for r in rows]
        r_val, (lo, hi), p = corr_ci(scores, gold, method="pearson", n_boot=1000, ci=0.95)
        out[key] = {"r": r_val, "ci": [lo, hi], "p": p}
    return out
def rq3_error_propagation(rows):
    """RQ3: does retrieval failure propagate into hallucination?

    Requires per-row ``retrieval_error`` and ``hallucination`` fields; returns
    ``{}`` when either is absent from the first row.
    """
    first = rows[0]
    if "retrieval_error" not in first or "hallucination" not in first:
        return {}
    retrieval = [r["retrieval_error"] for r in rows]
    hallucination = [r["hallucination"] for r in rows]
    result = {}
    result["conditional"] = conditional_failure_rate(retrieval, hallucination)
    result["chi2"] = chi2_error_propagation(retrieval, hallucination)
    return result
def rq4_robustness(orig_rows, pert_rows):
    """RQ4: per-metric robustness under perturbation.

    For every metric present in the first original row, computes the delta and
    Cohen's d effect size between original and perturbed runs via
    ``delta_metric``. Returns ``{}`` when no perturbed rows are supplied.
    Assumes both row lists carry the same metric keys — TODO confirm.
    """
    if pert_rows is None:
        return {}
    out = {}
    for name in orig_rows[0]["metrics"]:
        baseline = [r["metrics"][name] for r in orig_rows]
        perturbed = [r["metrics"][name] for r in pert_rows]
        delta, effect = delta_metric(baseline, perturbed)
        out[name] = {"delta": delta, "cohen_d": effect}
    return out
def scatter_mrr_vs_correct(rows, path: Path) -> None:
    """Save a scatter plot of MRR against binary human correctness to *path*.

    Rows without an ``mrr`` metric plot as NaN (i.e. are dropped by matplotlib);
    a missing/falsy ``human_correct`` plots as 0.
    """
    xs = [row["metrics"].get("mrr", np.nan) for row in rows]
    ys = [1 if row.get("human_correct") else 0 for row in rows]
    plt.figure()
    plt.scatter(xs, ys, alpha=0.5)
    plt.xlabel("MRR")
    plt.ylabel("Correct (1)")
    plt.title("MRR vs. Human Correctness")
    plt.tight_layout()
    plt.savefig(path)
    plt.close()
def main(argv=None):
    """CLI entry point.

    For each annotated results file: compute aggregate metrics and RQ1–RQ3
    analyses (plus RQ4 when a perturbed companion file exists), write each as
    YAML under ``outdir/<dataset>/<config>/``, and optionally save a scatter
    plot. When more than one config is processed, run pairwise Wilcoxon
    signed-rank tests on ``rag_score`` with Holm–Bonferroni correction.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--results", nargs="+", type=Path, required=True,
                        help="One or more annotated results.jsonl files.")
    parser.add_argument("--outdir", type=Path, default=Path("outputs/grid"))
    parser.add_argument("--perturbed-suffix", default="_pert.jsonl",
                        help="Looks for this perturbed variant for RQ4.")
    parser.add_argument("--plots", action="store_true")
    args = parser.parse_args(argv)

    init_logging(log_dir=args.outdir / "logs", level="INFO")
    log = logging.getLogger("resume")

    # Config name -> rows, collected for cross-config significance testing.
    historical = {}
    for res_path in args.results:
        # Directory layout assumption: <dataset>/<config>/<results file>.
        cfg_name = res_path.parent.name
        dataset_name = res_path.parent.parent.name
        log.info("Processing %s on %s", cfg_name, dataset_name)
        rows = read_jsonl(res_path)

        # Perturbed companion file (for RQ4), if one exists alongside.
        pert_name = res_path.stem.replace("unlabeled", "pert") + args.perturbed_suffix
        pert_path = res_path.with_name(pert_name)
        pert_rows = read_jsonl(pert_path) if pert_path.exists() else None

        run_dir = args.outdir / dataset_name / cfg_name
        run_dir.mkdir(parents=True, exist_ok=True)
        save_yaml(run_dir / "aggregates.yaml", agg_mean(rows))
        save_yaml(run_dir / "rq1.yaml", rq1_correlation(rows))
        save_yaml(run_dir / "rq2.yaml", rq2_faithfulness(rows))
        save_yaml(run_dir / "rq3.yaml", rq3_error_propagation(rows))
        if pert_rows:
            save_yaml(run_dir / "rq4.yaml", rq4_robustness(rows, pert_rows))
        if args.plots:
            scatter_mrr_vs_correct(rows, run_dir / "mrr_vs_correct.png")
        # NOTE(review): a repeated config name across datasets would overwrite
        # its earlier entry here — confirm configs are unique per invocation.
        historical[cfg_name] = rows

    # Pairwise Wilcoxon + Holm correction
    if len(historical) > 1:
        pairs = {}
        for a, b in itertools.combinations(list(historical), 2):
            x = [r["metrics"]["rag_score"] for r in historical[a]]
            y = [r["metrics"]["rag_score"] for r in historical[b]]
            _, p = wilcoxon_signed_rank(x, y)
            pairs[f"{a}~{b}"] = p
        # All cross-config outputs are filed under the first result's dataset.
        dataset_name = args.results[0].parent.parent.name
        save_yaml(args.outdir / dataset_name / "wilcoxon_rag_raw.yaml", pairs)
        save_yaml(args.outdir / dataset_name / "wilcoxon_rag_holm.yaml", holm_bonferroni(pairs))
        log.info("Pairwise significance testing complete (rag_score).")
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()