"""Score model predictions against gold references. Two evaluation modes, selected by ``--mode``: 1. ``bea`` -- single-reference ERRANT F0.5 against a BEA M2 file (canonical metric for the BEA-2019 shared task). Requires the ``errant_parallel`` and ``errant_compare`` CLIs from the ``errant`` package on PATH. 2. ``jfleg`` -- multi-reference exact-match accuracy on JFLEG (sentence matches *any* of the 4 references). Doesn't need ERRANT. Both modes also report: - parse-failure rate (how often the model's output had unmatched braces), - mean edit count per prediction, - "trivial copy" rate (model emitted the source unchanged). Usage:: python -m scripts.eval --mode bea \ --predictions results/predictions/sft_bea_dev.jsonl \ --ref-m2 data/raw/wi+locness/m2/ABCN.dev.gold.bea19.m2 \ --out results/metrics/sft_bea_dev.json python -m scripts.eval --mode jfleg \ --predictions results/predictions/sft_jfleg_dev.jsonl \ --jfleg-jsonl data/processed/eval_jfleg_dev.jsonl \ --out results/metrics/sft_jfleg_dev.json """ from __future__ import annotations import argparse import json import re import subprocess import sys from pathlib import Path from gec.parse import parse_inline WS_RE = re.compile(r"\s+") def _errant_bin(name: str) -> str: """Resolve an ERRANT CLI binary, preferring the current venv.""" candidate = Path(sys.executable).parent / name return str(candidate) if candidate.exists() else name def read_jsonl(path: Path) -> list[dict]: with path.open("r", encoding="utf-8") as f: return [json.loads(line) for line in f if line.strip()] def norm(s: str) -> str: return WS_RE.sub(" ", s.strip()) # -------------- BEA / ERRANT -------------- def write_lines(path: Path, lines: list[str]) -> None: path.parent.mkdir(parents=True, exist_ok=True) path.write_text("\n".join(lines) + "\n", encoding="utf-8") def m2_sources_in_order(m2_path: Path) -> list[str]: """Return the S-lines from the M2 file (preserves order).""" out: list[str] = [] with m2_path.open("r", encoding="utf-8") as f: for line in f: if line.startswith("S "): out.append(line[2:].rstrip("\n")) return out def errant_score(hyp_m2: Path, ref_m2: Path) -> dict: """Run errant_compare and parse out P / R / F0.5.""" cp = subprocess.run( [_errant_bin("errant_compare"), "-hyp", str(hyp_m2), "-ref", str(ref_m2)], check=True, capture_output=True, text=True, ) out = cp.stdout # Output ends with a row "TP\tFP\tFN\tPrec\tRec\tF0.5" # followed by another with the values. lines = [ln for ln in out.splitlines() if ln.strip()] header_idx = None for i, ln in enumerate(lines): tokens = ln.split() if "Prec" in tokens and "F0.5" in tokens: header_idx = i break if header_idx is None or header_idx + 1 >= len(lines): raise RuntimeError(f"Could not parse errant_compare output:\n{out}") header = lines[header_idx].split() values = lines[header_idx + 1].split() metrics = dict(zip(header, values)) return { "TP": int(metrics["TP"]), "FP": int(metrics["FP"]), "FN": int(metrics["FN"]), "P": float(metrics["Prec"]), "R": float(metrics["Rec"]), "F0.5": float(metrics["F0.5"]), "raw": out, } def eval_bea(args) -> dict: predictions = read_jsonl(args.predictions) sources_in_m2 = m2_sources_in_order(args.ref_m2) if len(predictions) != len(sources_in_m2): raise SystemExit( f"Length mismatch: {len(predictions)} predictions vs " f"{len(sources_in_m2)} sources in {args.ref_m2}. " "Generate against the same eval file used to build the M2.") for i, (pred, src) in enumerate(zip(predictions, sources_in_m2)): if norm(pred["source"]) != norm(src): raise SystemExit( f"Source order mismatch at row {i}:\n" f" pred: {pred['source']!r}\n" f" m2: {src!r}") work = args.out.parent / (args.out.stem + "_work") work.mkdir(parents=True, exist_ok=True) src_txt = work / "src.txt" hyp_txt = work / "hyp.txt" hyp_m2 = work / "hyp.m2" write_lines(src_txt, [norm(p["source"]) for p in predictions]) write_lines(hyp_txt, [norm(p["corrected"]) for p in predictions]) print(f"Building hypothesis M2 via errant_parallel -> {hyp_m2}") subprocess.run( [_errant_bin("errant_parallel"), "-orig", str(src_txt), "-cor", str(hyp_txt), "-out", str(hyp_m2)], check=True, ) print("Scoring with errant_compare …") err = errant_score(hyp_m2, args.ref_m2) parse_fail = sum(1 for p in predictions if not p["parse_ok"]) / len(predictions) trivial = sum(1 for p in predictions if norm(p["corrected"]) == norm(p["source"])) return { "mode": "bea", "n": len(predictions), "parse_failure_rate": parse_fail, "trivial_copy_rate": trivial / len(predictions), "errant": err, } # -------------- JFLEG -------------- def eval_jfleg(args) -> dict: predictions = read_jsonl(args.predictions) refs = read_jsonl(args.jfleg_jsonl) if len(predictions) != len(refs): raise SystemExit( f"Length mismatch: {len(predictions)} predictions vs {len(refs)} JFLEG rows") exact = 0 for pred, ref in zip(predictions, refs): if norm(pred["source"]) != norm(ref["source"]): raise SystemExit("Source order mismatch on JFLEG.") candidates = {norm(c) for c in ref["corrections"]} if norm(pred["corrected"]) in candidates: exact += 1 parse_fail = sum(1 for p in predictions if not p["parse_ok"]) / len(predictions) trivial = sum(1 for p in predictions if norm(p["corrected"]) == norm(p["source"])) return { "mode": "jfleg", "n": len(predictions), "exact_match_rate": exact / len(predictions), "parse_failure_rate": parse_fail, "trivial_copy_rate": trivial / len(predictions), } def main(): ap = argparse.ArgumentParser() ap.add_argument("--mode", choices=["bea", "jfleg"], required=True) ap.add_argument("--predictions", type=Path, required=True) ap.add_argument("--ref-m2", type=Path, help="BEA M2 reference (mode=bea)") ap.add_argument("--jfleg-jsonl", type=Path, help="JFLEG eval JSONL (mode=jfleg)") ap.add_argument("--out", type=Path, required=True) args = ap.parse_args() if args.mode == "bea": if not args.ref_m2: raise SystemExit("--ref-m2 is required for mode=bea") metrics = eval_bea(args) else: if not args.jfleg_jsonl: raise SystemExit("--jfleg-jsonl is required for mode=jfleg") metrics = eval_jfleg(args) args.out.parent.mkdir(parents=True, exist_ok=True) args.out.write_text(json.dumps(metrics, indent=2), encoding="utf-8") print(json.dumps({k: v for k, v in metrics.items() if k != "errant"}, indent=2)) if "errant" in metrics: e = metrics["errant"] print(f"\nERRANT: P={e['P']:.4f} R={e['R']:.4f} F0.5={e['F0.5']:.4f}" f" (TP={e['TP']} FP={e['FP']} FN={e['FN']})") if __name__ == "__main__": main()