gec-inline / scripts /qualitative_table.py
Lopato4ka's picture
Upload folder using huggingface_hub
32387a0 verified
Raw
History Blame Contribute Delete
3.27 kB
"""Produce a side-by-side qualitative comparison table for the report.

Reads N prediction JSONL files (one per checkpoint) plus the eval JSONL
with gold targets and selects two groups of examples: ``--n-hard``
(default 15) "hard" examples, ranked first by how many distinct outputs
the models produce for the sentence and then by the number of ``{``
markers in the gold completion, and ``--n-random`` (default 15) examples
drawn at random (seeded by ``--seed``) from the remaining rows. The
selected rows are written to ``--out`` as a markdown table.

Usage::

    python -m scripts.qualitative_table \
        --eval data/processed/eval_bea_dev.jsonl \
        --pred base=results/predictions/base_3shot_bea_dev.jsonl \
        --pred sft=results/predictions/sft_bea_dev.jsonl \
        --pred dpo=results/predictions/dpo_bea_dev.jsonl \
        --out results/qualitative.md \
        --seed 3407
"""
from __future__ import annotations
import argparse
import json
import random
from pathlib import Path
def read_jsonl(path: Path) -> list[dict]:
    """Load a JSON Lines file, returning one parsed value per non-blank line.

    Args:
        path: File to read; it is decoded as UTF-8.

    Returns:
        The parsed lines, in file order. Lines that are empty or contain
        only whitespace are skipped.
    """
    records: list[dict] = []
    with path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            # Tolerate blank separator lines and a trailing newline at EOF.
            if not raw_line.strip():
                continue
            records.append(json.loads(raw_line))
    return records
def md_escape(s: str) -> str:
    """Make *s* safe to place inside a single markdown table cell.

    Pipes are backslash-escaped so they are not read as column separators,
    and every line break (CRLF, bare CR, or LF) is replaced by one space so
    the cell cannot spill onto a second table row.

    Args:
        s: Arbitrary text to embed in a table cell.

    Returns:
        The escaped, single-line text.
    """
    # CRLF is handled first so a Windows line ending collapses to one space
    # rather than two; a bare CR is also a markdown line ending.
    return (
        s.replace("|", "\\|")
        .replace("\r\n", " ")
        .replace("\r", " ")
        .replace("\n", " ")
    )
def main(argv: list[str] | None = None) -> None:
    """Build the qualitative comparison table and write it as markdown.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, so running the file as a script
            behaves as before; passing a list allows calling ``main`` from
            other code.

    Raises:
        SystemExit: If the arguments are invalid (including a ``--pred``
            value that is not ``NAME=path`` or a repeated NAME), or if a
            prediction file does not have the same number of rows as the
            eval file.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--eval", required=True, type=Path)
    ap.add_argument(
        "--pred", action="append", required=True,
        help="NAME=path to predictions JSONL. Pass multiple times.",
    )
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument("--n-hard", type=int, default=15)
    ap.add_argument("--n-random", type=int, default=15)
    ap.add_argument("--seed", type=int, default=3407)
    args = ap.parse_args(argv)

    gold = read_jsonl(args.eval)
    preds: dict[str, list[dict]] = {}
    for spec in args.pred:
        name, sep, path = spec.partition("=")
        # Report a malformed spec as a usage error instead of an unpacking
        # traceback.
        if not sep or not name or not path:
            ap.error(f"--pred expects NAME=path, got {spec!r}")
        # A repeated name would silently replace the earlier column.
        if name in preds:
            ap.error(f"--pred name {name!r} given more than once")
        preds[name] = read_jsonl(Path(path))
        if len(preds[name]) != len(gold):
            raise SystemExit(
                f"{name}: {len(preds[name])} predictions vs {len(gold)} gold rows")

    rng = random.Random(args.seed)
    n = len(gold)

    # Hardness heuristic: rank rows by how many distinct outputs the models
    # produce (more disagreement first), breaking ties by how many edits the
    # gold completion contains (more edits first).
    def disagree(i: int) -> int:
        """Return the number of distinct stripped "corrected" outputs for row i."""
        return len({rows[i]["corrected"].strip() for rows in preds.values()})

    def edit_count(i: int) -> int:
        """Return the number of "{" characters in row i's gold completion."""
        # Presumably each "{" opens one inline edit -- confirm against the
        # data format. Rows without a "completion" key count as zero edits.
        return gold[i].get("completion", "").count("{")

    ranked = sorted(range(n), key=lambda i: (-disagree(i), -edit_count(i)))
    hard = ranked[: args.n_hard]

    # Build the exclusion set once rather than once per candidate index.
    hard_set = set(hard)
    pool = [i for i in range(n) if i not in hard_set]
    rng.shuffle(pool)
    rand = pool[: args.n_random]
    selected = [("hard", i) for i in hard] + [("random", i) for i in rand]

    cols = ["#", "kind", "source", "gold"] + list(preds)
    out_lines = [
        "| " + " | ".join(cols) + " |",
        "| " + " | ".join(["---"] * len(cols)) + " |",
    ]
    for row_no, (kind, i) in enumerate(selected, 1):
        g = gold[i]
        row = [
            str(row_no),
            kind,
            md_escape(g["source"]),
            # Fall back to "target" when the row has no "completion" field.
            md_escape(g.get("completion", g.get("target", ""))),
        ]
        # The table shows each model's "raw" output, whereas disagreement
        # above is measured on the "corrected" field.
        row.extend(md_escape(rows[i]["raw"]) for rows in preds.values())
        out_lines.append("| " + " | ".join(row) + " |")

    args.out.parent.mkdir(parents=True, exist_ok=True)
    args.out.write_text("\n".join(out_lines) + "\n", encoding="utf-8")
    print(f"Wrote {len(selected)} rows to {args.out}")
# Run the CLI only when this file is executed as a script (including via
# ``python -m``), not when it is imported as a module.
if __name__ == "__main__":
    main()