# eval_framework/evaluators/aggregate.py
# Uploaded by LCZZZZ: "Upload eval_framework source code" (commit 85b19cf, verified)
"""Roll up per-session and per-QA evaluations into baseline-level summaries.
Recall & correctness: per-session average (not pooled cumulative).
Update handling & interference: pooled across sessions.
QA & evidence: pooled across questions.
"""
from __future__ import annotations
from collections.abc import Mapping, Sequence
def _safe_div(a: float, b: float) -> float:
return a / b if b else 0.0
def aggregate_metrics(
    baseline_id: str,
    *,
    session_evaluations: Sequence[Mapping[str, object]] = (),
    qa_evaluations: Sequence[Mapping[str, object]] = (),
) -> dict[str, object]:
    """Build the baseline-level summary from session and QA evaluation records.

    Session records feed four sections of the result:

    * ``memory_recall`` -- mean of the non-``None`` ``recall`` and
      ``update_recall`` values, plus covered/gold totals taken only from
      sessions whose ``covered_count`` is not ``None``.
    * ``memory_correctness`` -- mean of the non-``None`` ``correctness_rate``
      values; hallucination and irrelevant rates are computed per session as
      ``count / num_memories`` (sessions without memories are skipped) and
      then averaged. Raw counts are also summed for reference.
    * ``update_handling`` -- counts summed over all sessions; the score gives
      full weight to ``update_num_updated`` and half weight to
      ``update_num_both``, divided by ``update_total_items``.
    * ``interference_rejection`` -- counts summed over all sessions; the
      score is rejected items over total items.

    QA records feed the remaining two sections:

    * ``question_answering`` -- share of ``Correct`` / ``Hallucination`` /
      ``Omission`` labels among the records carrying one of those three
      labels; any other label only counts towards ``num_total``.
    * ``evidence_coverage`` -- covered evidence over total evidence, taken
      only from records whose ``evidence_covered_count`` is not ``None``.

    Every ratio is computed with ``_safe_div``, so a zero denominator
    produces ``0.0`` rather than an error.

    Args:
        baseline_id: Identifier copied verbatim into the result.
        session_evaluations: Per-session evaluation mappings.
        qa_evaluations: Per-question evaluation mappings.

    Returns:
        A dict with ``baseline_id`` and the six metric sections named above.
    """

    def tally(record: Mapping[str, object], key: str) -> int:
        """Read an integer counter from *record*, using 0 when the key is absent."""
        return int(record.get(key, 0))

    def mean(values: list[float]) -> float:
        """Average *values*; an empty list averages to 0.0."""
        return _safe_div(sum(values), len(values))

    # Optional per-session scores, keyed by the field they are read from.
    rate_samples: dict[str, list[float]] = {
        "recall": [],
        "update_recall": [],
        "correctness_rate": [],
    }
    hallucination_rates: list[float] = []
    irrelevant_rates: list[float] = []

    # Counters that are simply summed over every session, keyed by field name.
    pooled: dict[str, int] = dict.fromkeys(
        (
            "num_correct",
            "num_hallucination",
            "num_irrelevant",
            "update_num_updated",
            "update_num_both",
            "update_num_outdated",
            "update_total_items",
            "interference_num_rejected",
            "interference_num_memorized",
            "interference_total_items",
        ),
        0,
    )
    covered_sum = 0
    gold_sum = 0
    memory_sum = 0

    for session in session_evaluations:
        # A score of None means "not available" and is left out of the mean.
        for field, samples in rate_samples.items():
            value = session.get(field)
            if value is not None:
                samples.append(float(value))

        # Per-session hallucination / irrelevant rates need a non-zero
        # memory count to be defined.
        memory_count = tally(session, "num_memories")
        if memory_count > 0:
            hallucination_rates.append(
                float(session.get("num_hallucination", 0)) / memory_count
            )
            irrelevant_rates.append(
                float(session.get("num_irrelevant", 0)) / memory_count
            )

        # Gold points are only counted alongside an available covered count.
        covered = session.get("covered_count")
        if covered is not None:
            covered_sum += int(covered)
            gold_sum += tally(session, "num_gold")

        memory_sum += memory_count
        for field in pooled:
            pooled[field] += tally(session, field)

    # --- QA and evidence (pooled over questions) ---
    qa_total = 0
    answered_correct = 0
    answered_hallucination = 0
    answered_omission = 0
    evidence_hits = 0
    evidence_items = 0

    for qa_total, answer in enumerate(qa_evaluations, start=1):
        label = answer.get("answer_label")
        if label == "Correct":
            answered_correct += 1
        elif label == "Hallucination":
            answered_hallucination += 1
        elif label == "Omission":
            answered_omission += 1

        hits = answer.get("evidence_covered_count")
        if hits is not None:
            evidence_hits += int(hits)
            evidence_items += tally(answer, "num_evidence")

    # Only answers carrying one of the three recognised labels are "valid".
    qa_valid = answered_correct + answered_hallucination + answered_omission

    recall_samples = rate_samples["recall"]
    update_recall_samples = rate_samples["update_recall"]
    correctness_samples = rate_samples["correctness_rate"]

    memory_recall = {
        "avg_recall": mean(recall_samples),
        "avg_update_recall": mean(update_recall_samples),
        "num_sessions_with_recall": len(recall_samples),
        "num_sessions_with_update": len(update_recall_samples),
        "total_covered": covered_sum,
        "total_gold": gold_sum,
    }
    memory_correctness = {
        "avg_correctness": mean(correctness_samples),
        "avg_hallucination": mean(hallucination_rates),
        "avg_irrelevant": mean(irrelevant_rates),
        "num_sessions": len(correctness_samples),
        "total_memories": memory_sum,
        "total_correct": pooled["num_correct"],
        "total_hallucination": pooled["num_hallucination"],
        "total_irrelevant": pooled["num_irrelevant"],
    }
    update_handling = {
        # Fully updated items weigh 1.0, items holding both versions 0.5.
        "score": _safe_div(
            pooled["update_num_updated"] * 1.0 + pooled["update_num_both"] * 0.5,
            pooled["update_total_items"],
        ),
        "num_updated": pooled["update_num_updated"],
        "num_both": pooled["update_num_both"],
        "num_outdated": pooled["update_num_outdated"],
        "num_total": pooled["update_total_items"],
    }
    interference_rejection = {
        "score": _safe_div(
            pooled["interference_num_rejected"],
            pooled["interference_total_items"],
        ),
        "num_rejected": pooled["interference_num_rejected"],
        "num_memorized": pooled["interference_num_memorized"],
        "num_total": pooled["interference_total_items"],
    }
    question_answering = {
        "correct_ratio": _safe_div(answered_correct, qa_valid),
        "hallucination_ratio": _safe_div(answered_hallucination, qa_valid),
        "omission_ratio": _safe_div(answered_omission, qa_valid),
        "num_total": qa_total,
        "num_valid": qa_valid,
    }
    evidence_coverage = {
        "hit_rate": _safe_div(evidence_hits, evidence_items),
        "num_covered": evidence_hits,
        "num_total": evidence_items,
    }

    return {
        "baseline_id": baseline_id,
        "memory_recall": memory_recall,
        "memory_correctness": memory_correctness,
        "update_handling": update_handling,
        "interference_rejection": interference_rejection,
        "question_answering": question_answering,
        "evidence_coverage": evidence_coverage,
    }