amul-ai-eval / scripts /summarize_run.py
bpHigh's picture
HF Space: add charts tab
74e6b83
Raw
History Blame Contribute Delete
12.6 kB
"""Summarize a completed CeRAI run from its SQLite DB.
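Joins TestRuns, TestRunDetails, Metrics, TestCases, Prompts, Responses and
Conversations for one run and prints per-metric aggregates: count, latency
(mean/p50/p95), MCQ accuracy (the letter the agent answered vs. the first letter
of the expected response, for testcases whose name contains `-MCQ-`), and the
average evaluation_score if the analyzer has already populated it. A
per-dataset x per-metric breakdown is printed afterwards.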
Also writes a machine-readable JSON snapshot to
results/<run_name>/summary.json so it can be embedded in the live endpoint,
and a human-readable per-prompt Markdown table to results/<run_name>/results.md.
Usage:
python scripts/summarize_run.py --run-name run-primary-disk-penatibus
"""
from __future__ import annotations
import argparse
import json
import re
import sqlite3
from pathlib import Path
from statistics import mean
# Directory two levels above this file (this file's parent directory's parent);
# used as the base for the default database path and the results/ output tree.
ROOT = Path(__file__).resolve().parent.parent
# Default SQLite database location; main() uses it as the default for --db.
DB = ROOT / "AIEvaluationTool" / "data" / "AIEvaluationData.db"
def _fetch_run_rows(con: sqlite3.Connection, run_name: str) -> list[sqlite3.Row]:
sql = """
SELECT tr.run_id, tr.run_name, tr.status,
rd.detail_id, rd.testcase_status,
m.metric_id, m.metric_name,
tc.testcase_id, tc.testcase_name,
p.user_prompt,
r.response_text AS expected,
c.agent_response,
c.evaluation_score,
c.evaluation_reason,
c.prompt_ts, c.response_ts
FROM TestRuns tr
JOIN TestRunDetails rd ON rd.run_id = tr.run_id
JOIN Metrics m ON m.metric_id = rd.metric_id
JOIN TestCases tc ON tc.testcase_id = rd.testcase_id
LEFT JOIN Prompts p ON p.prompt_id = tc.prompt_id
LEFT JOIN Responses r ON r.response_id = tc.response_id
LEFT JOIN Conversations c ON c.detail_id = rd.detail_id
WHERE tr.run_name = ?
ORDER BY rd.detail_id
"""
return list(con.execute(sql, (run_name,)).fetchall())
def _percentile(xs: list[float], p: float) -> float:
if not xs:
return 0.0
xs = sorted(xs)
k = max(0, min(len(xs) - 1, int(round(p / 100.0 * (len(xs) - 1)))))
return xs[k]
def _elapsed_seconds(prompt_ts: str | None, response_ts: str | None) -> float | None:
if not prompt_ts or not response_ts:
return None
from datetime import datetime
return (datetime.fromisoformat(response_ts) - datetime.fromisoformat(prompt_ts)).total_seconds()
def _mcq_letter(text: str | None) -> str | None:
if not text:
return None
m = re.match(r"\s*\*?\*?([A-D])[\.\)\*\s]?", text.strip())
return m.group(1).upper() if m else None
def _is_mcq_case(testcase_name: str) -> bool:
return "-MCQ-" in testcase_name
# Cell text used in results.md when a value is missing.
_MISSING_CELL = "—"

# (column header label, db metric_name) for the per-prompt pivot table.
_METRIC_COLS = [
    ("TAT (lower latency ↑)", "Turn_Around_Time"),
    ("ErrRate (fewer errors ↑)", "Error_Rate"),
    ("BLEU (more overlap ↑)", "BLEU"),
    ("ROUGE (more overlap ↑)", "ROUGE"),
    ("METEOR (more overlap ↑)", "METEOR"),
    ("LexDiv (richer vocab ↑)", "Lexical_Diversity"),
]

# Pivot rows are keyed by the *prompt*: strip the -M{metric_id} suffix that
# scripts/build_*_testcases.py adds to give each metric its own testcase row.
_SUFFIX_RE = re.compile(r"-M\d+$")


def _dataset_of(name: str) -> str:
    """Map a testcase name to its dataset bucket based on its prefix."""
    if name.startswith("BBK-"):
        return "bbk"
    if name.startswith("KCC-"):
        return "kcc"
    return "other"


def _group_rows(rows, key_fn) -> dict[str, list[sqlite3.Row]]:
    """Group *rows* by ``key_fn(row)``, keeping first-seen key order."""
    grouped: dict[str, list[sqlite3.Row]] = {}
    for row in rows:
        grouped.setdefault(key_fn(row), []).append(row)
    return grouped


def _row_latencies(rows) -> list[float]:
    """Return the elapsed seconds of every row that has both timestamps."""
    timings = (_elapsed_seconds(r["prompt_ts"], r["response_ts"]) for r in rows)
    return [t for t in timings if t is not None]


def _row_scores(rows) -> list[float]:
    """Return the non-NULL evaluation scores of *rows*."""
    return [r["evaluation_score"] for r in rows if r["evaluation_score"] is not None]


def _rounded_mean(values: list[float]) -> float | None:
    """Return the mean of *values* rounded to 3 places, or None when empty."""
    return round(mean(values), 3) if values else None


def _mcq_stats(rows) -> dict | None:
    """Return MCQ accuracy over *rows*, or None when no row is scorable.

    A row is scorable when its testcase is an MCQ case, the expected response
    is non-empty, and an answer letter can be extracted from the agent reply.
    """
    total = correct = 0
    for r in rows:
        if not _is_mcq_case(r["testcase_name"]):
            continue
        expected = (r["expected"] or "").strip().upper()[:1]
        got = _mcq_letter(r["agent_response"])
        if expected and got:
            total += 1
            correct += int(got == expected)
    if total == 0:
        return None
    return {"n": total, "correct": correct, "accuracy": round(correct / total, 3)}


def _summarize_metric(mrows) -> dict:
    """Build the per-metric summary entry stored under ``summary["metrics"]``."""
    latencies = _row_latencies(mrows)
    scores = _row_scores(mrows)
    return {
        "n": len(mrows),
        "completed": sum(1 for r in mrows if r["testcase_status"] == "COMPLETED"),
        "failed": sum(1 for r in mrows if r["testcase_status"] == "FAILED"),
        "latency_sec": {
            "mean": _rounded_mean(latencies),
            "p50": round(_percentile(latencies, 50), 3) if latencies else None,
            "p95": round(_percentile(latencies, 95), 3) if latencies else None,
            "min": round(min(latencies), 3) if latencies else None,
            "max": round(max(latencies), 3) if latencies else None,
        },
        "evaluation_score": {
            "n": len(scores),
            "mean": _rounded_mean(scores),
        },
        "mcq_accuracy": _mcq_stats(mrows),
    }


def _summarize_dataset_metric(mrows) -> dict:
    """Build one cell of the per-dataset x per-metric breakdown."""
    return {
        "n": len(mrows),
        "score_mean": _rounded_mean(_row_scores(mrows)),
        "latency_mean": _rounded_mean(_row_latencies(mrows)),
        "mcq": _mcq_stats(mrows),
    }


def _cell(value) -> str:
    """Render *value* for a Markdown cell; only None counts as missing."""
    return _MISSING_CELL if value is None else str(value)


def _clip(s: str | None, n: int) -> str:
    """Make *s* safe for a single table cell and truncate it to *n* chars."""
    s = (s or "").replace("|", "\\|").replace("\n", " ").strip()
    return (s[:n] + "…") if len(s) > n else s


def _render_results_md(run_name: str, rows, metrics_summary: dict) -> str:
    """Render results.md: a per-metric summary table and a per-prompt table."""
    metric_names = sorted(metrics_summary)
    md_lines: list[str] = [
        f"# Run: `{run_name}`",
        "",
        f"- **Total test cases**: {len(rows)}",
        f"- **Metrics covered**: {', '.join(metric_names)}",
        "",
        "## Per-metric summary",
        "",
        "| Metric | n | Completed | Failed | Mean score | Mean latency (s) | p95 latency (s) | MCQ accuracy |",
        "|---|---|---|---|---|---|---|---|",
    ]
    for metric in metric_names:
        e = metrics_summary[metric]
        mcq = e["mcq_accuracy"]
        mcq_text = (
            _MISSING_CELL if not mcq
            else f"{mcq['correct']}/{mcq['n']} ({mcq['accuracy']*100:.0f}%)"
        )
        # _cell() tests for None explicitly so a real 0.0 latency is shown
        # as a number rather than as a missing value.
        md_lines.append(
            f"| {metric} | {e['n']} | {e['completed']} | {e['failed']} | "
            f"{_cell(e['evaluation_score']['mean'])} | "
            f"{_cell(e['latency_sec']['mean'])} | "
            f"{_cell(e['latency_sec']['p95'])} | {mcq_text} |"
        )

    md_lines.append("")
    md_lines.append("## Per-prompt detail")
    md_lines.append("")
    md_lines.append(
        "One row per testcase. Each metric column shows the **CeRAI normalized score** "
        "for that metric on that prompt — CeRAI maps every metric to a 0–1 scale where "
        "**higher = better** by convention, even when the underlying quantity is "
        "directional the other way (e.g. lower latency or fewer errors). The column "
        "headers note the underlying direction in parens."
    )
    md_lines.append("")

    # All metric scores for one prompt end up on one row. The prompt, expected
    # and got texts come from the first row seen for that prompt.
    by_tc: dict[str, dict] = {}
    for r in rows:
        key = _SUFFIX_RE.sub("", r["testcase_name"])
        latency = _elapsed_seconds(r["prompt_ts"], r["response_ts"])
        slot = by_tc.setdefault(key, {
            "prompt": r["user_prompt"],
            "expected": r["expected"],
            "got": r["agent_response"],
            "latencies": [],
            "scores": {},
        })
        if latency is not None:
            slot["latencies"].append(latency)
        slot["scores"][r["metric_name"]] = r["evaluation_score"]

    header = (
        "| testcase | latency (s) | prompt | got | expected | "
        + " | ".join(label for label, _ in _METRIC_COLS)
        + " |"
    )
    md_lines.append(header)
    md_lines.append("|---" * (5 + len(_METRIC_COLS)) + "|")
    for tc in sorted(by_tc):
        slot = by_tc[tc]
        latency = (round(mean(slot["latencies"]), 2)
                   if slot["latencies"] else _MISSING_CELL)
        cells = [
            f"`{tc}`",
            str(latency),
            _clip(slot["prompt"], 80),
            _clip(slot["got"], 80),
            _clip(slot["expected"], 50),
        ]
        for _, metric_name in _METRIC_COLS:
            score = slot["scores"].get(metric_name)
            cells.append(
                f"{round(score, 3)}" if score is not None else _MISSING_CELL
            )
        md_lines.append("| " + " | ".join(cells) + " |")
    return "\n".join(md_lines)


def main() -> None:
    """Summarize one run: print aggregates and write summary.json / results.md.

    Command-line options:
        --run-name  (required) name of the run to summarize.
        --db        path to the SQLite database; defaults to DB.

    Output goes to stdout and to ``ROOT / "results" / <run-name>``.

    Raises:
        SystemExit: when the database file does not exist or the run has no
            rows.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run-name", required=True)
    parser.add_argument("--db", default=str(DB))
    args = parser.parse_args()

    # sqlite3.connect() would silently create an empty database for a wrong
    # path and then fail with "no such table"; report the real problem instead.
    if not Path(args.db).is_file():
        raise SystemExit(f"database not found: {args.db}")

    con = sqlite3.connect(args.db)
    try:
        con.row_factory = sqlite3.Row
        rows = _fetch_run_rows(con, args.run_name)
    finally:
        # Rows are fully fetched, so the connection can be released right away.
        con.close()
    if not rows:
        raise SystemExit(f"no rows found for run '{args.run_name}'")

    by_metric = _group_rows(rows, lambda r: r["metric_name"])
    by_dataset = _group_rows(rows, lambda r: _dataset_of(r["testcase_name"]))

    summary: dict = {
        "run_name": args.run_name,
        "total_testcases": len(rows),
        "datasets": {ds: len(rs) for ds, rs in by_dataset.items()},
        "metrics": {},
        "by_dataset_metric": {},
    }

    print(f"\n=== Run: {args.run_name} ({len(rows)} test cases) ===\n")
    for metric, mrows in sorted(by_metric.items()):
        entry = _summarize_metric(mrows)
        summary["metrics"][metric] = entry
        print(f"-- {metric} --")
        print(f" tests : {entry['completed']} completed / {entry['failed']} failed / {entry['n']} total")
        print(f" latency (s) : mean={entry['latency_sec']['mean']} p50={entry['latency_sec']['p50']} p95={entry['latency_sec']['p95']}")
        if entry["evaluation_score"]["mean"] is not None:
            print(f" eval score : mean={entry['evaluation_score']['mean']} (n={entry['evaluation_score']['n']})")
        if entry["mcq_accuracy"]:
            mcq = entry["mcq_accuracy"]
            print(f" MCQ accuracy : {mcq['correct']}/{mcq['n']} = {mcq['accuracy']*100:.1f}%")
        print()

    # Per-dataset × per-metric breakdown so we can talk about BBK vs KCC separately.
    print("=== Per-dataset × per-metric breakdown ===\n")
    for ds in sorted(by_dataset):
        ds_rows = by_dataset[ds]
        ds_metrics = _group_rows(ds_rows, lambda r: r["metric_name"])
        breakdown = {
            metric: _summarize_dataset_metric(ds_metrics[metric])
            for metric in sorted(ds_metrics)
        }
        summary["by_dataset_metric"][ds] = breakdown
        print(f"-- dataset: {ds.upper()} ({len(ds_rows)} test cases) --")
        for metric, info in breakdown.items():
            mcq = info["mcq"]
            mcq_text = "" if not mcq else f", MCQ {mcq['correct']}/{mcq['n']}"
            print(f" {metric:18s} n={info['n']:>3} score_mean={info['score_mean']} "
                  f"latency_mean={info['latency_mean']}{mcq_text}")
        print()

    out_dir = ROOT / "results" / args.run_name
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f"Wrote {out_dir/'summary.json'}")

    # results.md: a per-prompt table that a human can scan at a glance.
    # It contains non-ASCII characters, so the encoding is pinned to UTF-8
    # rather than left to the platform default.
    results_md = _render_results_md(args.run_name, rows, summary["metrics"])
    (out_dir / "results.md").write_text(results_md, encoding="utf-8")
    print(f"Wrote {out_dir/'results.md'}")


if __name__ == "__main__":
    main()