| """Roll up per-session and per-QA evaluations into baseline-level summaries. |
| |
| Recall & correctness: per-session average (not pooled cumulative). |
| Update handling & interference: pooled across sessions. |
| QA & evidence: pooled across questions. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from collections.abc import Mapping, Sequence |
|
|
|
|
| def _safe_div(a: float, b: float) -> float: |
| return a / b if b else 0.0 |
|
|
|
|
def _field_or_zero(record: Mapping[str, object], key: str) -> object:
    """Return ``record[key]``, or ``0`` when the key is missing or maps to ``None``.

    ``Mapping.get(key, 0)`` only substitutes the default for a *missing* key;
    an explicit ``None`` would otherwise reach ``int()`` / ``float()`` and
    raise ``TypeError``.
    """
    value = record.get(key)
    return 0 if value is None else value


def aggregate_metrics(
    baseline_id: str,
    *,
    session_evaluations: Sequence[Mapping[str, object]] = (),
    qa_evaluations: Sequence[Mapping[str, object]] = (),
) -> dict[str, object]:
    """Aggregate all per-session and per-QA evaluations.

    Args:
        baseline_id: Identifier copied verbatim into the result.
        session_evaluations: Per-session records. The keys read are
            ``recall``, ``update_recall``, ``correctness_rate``,
            ``num_memories``, ``num_hallucination``, ``num_irrelevant``,
            ``covered_count``, ``num_gold``, ``num_correct``,
            ``update_num_updated``, ``update_num_both``,
            ``update_num_outdated``, ``update_total_items``,
            ``interference_num_rejected``, ``interference_num_memorized``
            and ``interference_total_items``. Count fields that are missing
            or ``None`` are treated as 0.
        qa_evaluations: Per-question records. The keys read are
            ``answer_label``, ``evidence_covered_count`` and ``num_evidence``.

    Returns:
        A dict with ``baseline_id`` plus the sections ``memory_recall``,
        ``memory_correctness``, ``update_handling``,
        ``interference_rejection``, ``question_answering`` and
        ``evidence_coverage``. Every ratio is 0.0 when its denominator is 0.
    """
    # Per-session rates; each list is averaged over the sessions that
    # actually contributed a value.
    recall_scores: list[float] = []
    update_recall_scores: list[float] = []
    correctness_scores: list[float] = []
    hallucination_scores: list[float] = []
    irrelevant_scores: list[float] = []

    # Update-handling counts, pooled across sessions.
    upd_num_updated = 0
    upd_num_both = 0
    upd_num_outdated = 0
    upd_total_items = 0

    # Interference counts, pooled across sessions.
    interf_num_rejected = 0
    interf_num_memorized = 0
    interf_total_items = 0

    # Raw totals reported next to the per-session averages.
    total_gold_points = 0
    total_covered = 0
    total_memories = 0
    total_correct = 0
    total_hallucination = 0
    total_irrelevant = 0

    for s in session_evaluations:
        r = s.get("recall")
        if r is not None:
            recall_scores.append(float(r))

        ur = s.get("update_recall")
        if ur is not None:
            update_recall_scores.append(float(ur))

        cr = s.get("correctness_rate")
        if cr is not None:
            correctness_scores.append(float(cr))

        # Hallucination / irrelevant rates are only defined for sessions
        # that produced at least one memory.
        nm = int(_field_or_zero(s, "num_memories"))
        if nm > 0:
            hallucination_scores.append(
                float(_field_or_zero(s, "num_hallucination")) / nm
            )
            irrelevant_scores.append(
                float(_field_or_zero(s, "num_irrelevant")) / nm
            )

        # Gold points are counted only for sessions that report coverage,
        # so the covered/gold totals stay paired.
        c = s.get("covered_count")
        if c is not None:
            total_covered += int(c)
            total_gold_points += int(_field_or_zero(s, "num_gold"))
        total_memories += nm
        total_correct += int(_field_or_zero(s, "num_correct"))
        total_hallucination += int(_field_or_zero(s, "num_hallucination"))
        total_irrelevant += int(_field_or_zero(s, "num_irrelevant"))

        upd_num_updated += int(_field_or_zero(s, "update_num_updated"))
        upd_num_both += int(_field_or_zero(s, "update_num_both"))
        upd_num_outdated += int(_field_or_zero(s, "update_num_outdated"))
        upd_total_items += int(_field_or_zero(s, "update_total_items"))

        interf_num_rejected += int(_field_or_zero(s, "interference_num_rejected"))
        interf_num_memorized += int(_field_or_zero(s, "interference_num_memorized"))
        interf_total_items += int(_field_or_zero(s, "interference_total_items"))

    # Question answering and evidence coverage, pooled across questions.
    qa_total = 0
    qa_valid = 0
    qa_correct = 0
    qa_hallucination = 0
    qa_omission = 0
    evidence_covered = 0
    evidence_total = 0

    for q in qa_evaluations:
        qa_total += 1
        label = q.get("answer_label")
        # Only the three recognised labels count as valid; any other label
        # still counts towards ``num_total``.
        if label == "Correct":
            qa_valid += 1
            qa_correct += 1
        elif label == "Hallucination":
            qa_valid += 1
            qa_hallucination += 1
        elif label == "Omission":
            qa_valid += 1
            qa_omission += 1

        ec = q.get("evidence_covered_count")
        if ec is not None:
            evidence_covered += int(ec)
            evidence_total += int(_field_or_zero(q, "num_evidence"))

    n_recall = len(recall_scores)
    n_update = len(update_recall_scores)
    n_correct = len(correctness_scores)
    n_hallu = len(hallucination_scores)
    n_irrel = len(irrelevant_scores)

    return {
        "baseline_id": baseline_id,
        "memory_recall": {
            "avg_recall": _safe_div(sum(recall_scores), n_recall),
            "avg_update_recall": _safe_div(sum(update_recall_scores), n_update),
            "num_sessions_with_recall": n_recall,
            "num_sessions_with_update": n_update,
            "total_covered": total_covered,
            "total_gold": total_gold_points,
        },
        "memory_correctness": {
            "avg_correctness": _safe_div(sum(correctness_scores), n_correct),
            "avg_hallucination": _safe_div(sum(hallucination_scores), n_hallu),
            "avg_irrelevant": _safe_div(sum(irrelevant_scores), n_irrel),
            "num_sessions": n_correct,
            "total_memories": total_memories,
            "total_correct": total_correct,
            "total_hallucination": total_hallucination,
            "total_irrelevant": total_irrelevant,
        },
        "update_handling": {
            # A fully updated item scores 1.0; an item counted under
            # ``update_num_both`` scores 0.5; outdated items score 0.
            "score": _safe_div(
                upd_num_updated * 1.0 + upd_num_both * 0.5, upd_total_items
            ),
            "num_updated": upd_num_updated,
            "num_both": upd_num_both,
            "num_outdated": upd_num_outdated,
            "num_total": upd_total_items,
        },
        "interference_rejection": {
            "score": _safe_div(interf_num_rejected, interf_total_items),
            "num_rejected": interf_num_rejected,
            "num_memorized": interf_num_memorized,
            "num_total": interf_total_items,
        },
        "question_answering": {
            "correct_ratio": _safe_div(qa_correct, qa_valid),
            "hallucination_ratio": _safe_div(qa_hallucination, qa_valid),
            "omission_ratio": _safe_div(qa_omission, qa_valid),
            "num_total": qa_total,
            "num_valid": qa_valid,
        },
        "evidence_coverage": {
            "hit_rate": _safe_div(evidence_covered, evidence_total),
            "num_covered": evidence_covered,
            "num_total": evidence_total,
        },
    }
|
|