"""
Runs evaluation (RQ1–RQ4, statistical tests, plots) on previously annotated
pipeline outputs that include `human_correct` and `human_faithful`.
Assumes outputs were generated using `separate_for_annotation.py` and
subsequently annotated.
"""
import argparse
import itertools
import json
import logging
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import yaml

from evaluation.stats import (
    corr_ci,
    wilcoxon_signed_rank,
    holm_bonferroni,
    conditional_failure_rate,
    chi2_error_propagation,
    delta_metric,
)
from evaluation.utils.logger import init_logging


def read_jsonl(path: Path) -> list[dict]:
    """Read one JSON object per line from `path`."""
    with path.open() as f:
        return [json.loads(line) for line in f]
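
# Expected row shape (illustrative; field names inferred from the usage below,
# value types are assumptions):
#   {"metrics": {"mrr": 0.41, "map": 0.37, "rag_score": 0.72, ...},
#    "human_correct": true, "human_faithful": 0.8,
#    "retrieval_error": false, "hallucination": false}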


def save_yaml(path: Path, obj: dict) -> None:
    """Dump `obj` to YAML, creating parent directories as needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(yaml.safe_dump(obj, sort_keys=False))


def agg_mean(rows: list[dict]) -> dict:
    """Mean of each metric over all rows (metric names taken from the first row)."""
    keys = rows[0]["metrics"].keys()
    return {k: float(np.mean([r["metrics"][k] for r in rows])) for k in keys}


def rq1_correlation(rows):
    """RQ1: correlate retrieval metrics with binary human correctness labels."""
    if "human_correct" not in rows[0] or rows[0]["human_correct"] is None:
        return {}
    retrieval_keys = [k for k in rows[0]["metrics"] if k in {"mrr", "map", "precision@10"}]
    gold = [1.0 if r["human_correct"] else 0.0 for r in rows]
    out = {}
    for k in retrieval_keys:
        vec = [r["metrics"][k] for r in rows]
        r, (lo, hi), p = corr_ci(vec, gold, method="pearson", n_boot=1000, ci=0.95)
        out[k] = dict(r=r, ci=[lo, hi], p=p)
    return out


def rq2_faithfulness(rows):
    """RQ2: correlate automatic faithfulness metrics with human judgments."""
    if "human_faithful" not in rows[0] or rows[0]["human_faithful"] is None:
        return {}
    faith_keys = [k for k in rows[0]["metrics"] if k.lower().startswith(("faith", "qags", "fact", "ragas"))]
    # Cast to float so boolean and graded annotations are handled alike.
    gold = [float(r["human_faithful"]) for r in rows]
    out = {}
    for k in faith_keys:
        vec = [r["metrics"][k] for r in rows]
        r, (lo, hi), p = corr_ci(vec, gold, method="pearson", n_boot=1000, ci=0.95)
        out[k] = dict(r=r, ci=[lo, hi], p=p)
    return out


def rq3_error_propagation(rows):
    """RQ3: do retrieval errors propagate into generation hallucinations?"""
    if "retrieval_error" not in rows[0] or "hallucination" not in rows[0]:
        return {}
    ret_err = [r["retrieval_error"] for r in rows]
    halluc = [r["hallucination"] for r in rows]
    return {
        "conditional": conditional_failure_rate(ret_err, halluc),
        "chi2": chi2_error_propagation(ret_err, halluc),
    }


def rq4_robustness(orig_rows, pert_rows):
    """RQ4: per-metric delta and effect size between original and perturbed runs."""
    if pert_rows is None:
        return {}
    metrics = orig_rows[0]["metrics"].keys()
    out = {}
    for m in metrics:
        d, eff = delta_metric(
            [r["metrics"][m] for r in orig_rows],
            [r["metrics"][m] for r in pert_rows],
        )
        out[m] = dict(delta=d, cohen_d=eff)
    return out


def scatter_mrr_vs_correct(rows, path: Path):
    """Scatter plot of per-example MRR against the binary correctness label."""
    x = [r["metrics"].get("mrr", np.nan) for r in rows]
    y = [1 if r.get("human_correct") else 0 for r in rows]
    plt.figure()
    plt.scatter(x, y, alpha=0.5)
    plt.xlabel("MRR")
    plt.ylabel("Human correct (0/1)")
    plt.title("MRR vs. Human Correctness")
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


def main(argv=None):
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--results", nargs="+", type=Path, required=True,
                    help="One or more annotated results.jsonl files.")
    ap.add_argument("--outdir", type=Path, default=Path("outputs/grid"))
    ap.add_argument("--perturbed-suffix", default="_pert.jsonl",
                    help="Suffix of the perturbed-run variant searched for per result (RQ4).")
    ap.add_argument("--plots", action="store_true")
    args = ap.parse_args(argv)

    init_logging(log_dir=args.outdir / "logs", level="INFO")
    log = logging.getLogger("resume")

    historical = {}
    for res_path in args.results:
        # Expected layout: .../<dataset>/<config>/results.jsonl
        cfg_name = res_path.parent.name
        dataset_name = res_path.parent.parent.name
        log.info("Processing %s on %s", cfg_name, dataset_name)
        rows = read_jsonl(res_path)

        pert_path = res_path.with_name(res_path.stem.replace("unlabeled", "pert") + args.perturbed_suffix)
        pert_rows = read_jsonl(pert_path) if pert_path.exists() else None

        run_dir = args.outdir / dataset_name / cfg_name
        run_dir.mkdir(parents=True, exist_ok=True)
        save_yaml(run_dir / "aggregates.yaml", agg_mean(rows))
        save_yaml(run_dir / "rq1.yaml", rq1_correlation(rows))
        save_yaml(run_dir / "rq2.yaml", rq2_faithfulness(rows))
        save_yaml(run_dir / "rq3.yaml", rq3_error_propagation(rows))
        if pert_rows:
            save_yaml(run_dir / "rq4.yaml", rq4_robustness(rows, pert_rows))
        if args.plots:
            scatter_mrr_vs_correct(rows, run_dir / "mrr_vs_correct.png")
        historical[cfg_name] = rows

    # Pairwise Wilcoxon signed-rank tests on rag_score, Holm-Bonferroni corrected.
    if len(historical) > 1:
        names = list(historical)
        pairs = {}
        for a, b in itertools.combinations(names, 2):
            x = [r["metrics"]["rag_score"] for r in historical[a]]
            y = [r["metrics"]["rag_score"] for r in historical[b]]
            _, p = wilcoxon_signed_rank(x, y)
            pairs[f"{a}~{b}"] = p
        dataset_name = args.results[0].parent.parent.name
        save_yaml(args.outdir / dataset_name / "wilcoxon_rag_raw.yaml", pairs)
        save_yaml(args.outdir / dataset_name / "wilcoxon_rag_holm.yaml", holm_bonferroni(pairs))
        log.info("Pairwise significance testing complete (rag_score).")


if __name__ == "__main__":
    main()