| """Unified session evaluation: recall + correctness (includes update & interference). |
| |
| Per session, 2 batch judge calls — both scoped to THIS SESSION's memory delta only — plus one single-item judge call per update gold point and per interference gold point: |
| Call 1 — Recall: how many of this session's gold points are covered by the |
| session's memory delta (add/update ops)? |
| Call 2 — Correctness: is each delta memory correct, hallucinated, or irrelevant? |
| (reference = this session's gold points + interference) |
| |
| Aggregate: per-session recall/correctness averaged across sessions. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from eval_framework.judges import ( |
| evaluate_correctness_batch, |
| evaluate_interference_single, |
| evaluate_recall_batch, |
| evaluate_update_single, |
| ) |
| from eval_framework.pipeline.records import PipelineSessionRecord |
|
|
|
|
def _delta_to_text(session: PipelineSessionRecord) -> str:
    """Render this session's add/update delta entries as a numbered listing.

    Walks ``session.memory_delta`` in order, keeps the entries whose ``op`` is
    ``"add"`` or ``"update"``, and formats each kept entry as ``[k] <text>``,
    where ``k`` counts from 1 over the kept entries only (skipped entries do
    not consume a number).  Lines are joined with newlines; the result is the
    empty string when no entry qualifies.
    """
    kept_texts = (
        entry.text for entry in session.memory_delta if entry.op in ("add", "update")
    )
    return "\n".join(
        f"[{position}] {text}" for position, text in enumerate(kept_texts, start=1)
    )
|
|
|
|
def _delta_texts(session: PipelineSessionRecord) -> list[str]:
    """Collect the texts of this session's add/update delta entries.

    Returns the ``text`` of every entry in ``session.memory_delta`` whose
    ``op`` is ``"add"`` or ``"update"``, in delta order.  Entries with any
    other ``op`` are left out.
    """
    kept: list[str] = []
    for entry in session.memory_delta:
        if entry.op in ("add", "update"):
            kept.append(entry.text)
    return kept
|
|
|
|
def _build_recall_gold_points(session: PipelineSessionRecord) -> list[str]:
    """Build the tagged gold points used as the recall reference.

    Only the current session's gold state is read (nothing cumulative):
    every entry of ``session_new_memories`` becomes ``"[normal] <content>"``
    and every entry of ``session_update_memories`` becomes
    ``"[update] <content>"``.  New memories come first, then updates.
    """
    gold = session.gold_state
    fresh = [f"[normal] {item.memory_content}" for item in gold.session_new_memories]
    revised = [f"[update] {item.memory_content}" for item in gold.session_update_memories]
    return fresh + revised
|
|
|
|
def _build_correctness_gold_points(session: PipelineSessionRecord) -> list[str]:
    """Build the tagged gold points used as the correctness reference.

    Reads the current session's gold state and emits, in this order, one
    ``"[normal] <content>"`` line per new memory, one ``"[update] <content>"``
    line per update memory and one ``"[interference] <content>"`` line per
    interference memory.
    """
    gold = session.gold_state
    labelled_groups = (
        ("[normal]", gold.session_new_memories),
        ("[update]", gold.session_update_memories),
        ("[interference]", gold.session_interference_memories),
    )
    points: list[str] = []
    for label, group in labelled_groups:
        points.extend(f"{label} {item.memory_content}" for item in group)
    return points
|
|
|
|
def _safe_ratio(numerator: object, denominator: object) -> float | None:
    """Return ``numerator / denominator`` as a float, or None when undefined.

    The ratio is undefined when the numerator is None or the denominator is
    falsy (None or zero).
    """
    if numerator is None or not denominator:
        return None
    return float(numerator) / float(denominator)


def _recall_judgement(delta_str: str, recall_gold: list[str]) -> dict[str, object]:
    """Return the recall result dict, calling the batch judge only when needed.

    No gold points, or an empty delta listing, yields a locally built result
    with zero covered counts; otherwise ``evaluate_recall_batch`` decides.
    """
    if not recall_gold:
        return {
            "covered_count": 0, "update_covered_count": 0,
            "total": 0, "update_total": 0,
            "reasoning": "No new gold points in this session.",
        }
    if not delta_str.strip():
        update_total = sum(1 for p in recall_gold if p.startswith("[update]"))
        return {
            "covered_count": 0, "update_covered_count": 0,
            "total": len(recall_gold), "update_total": update_total,
            "reasoning": "No add/update memories in this session's delta.",
        }
    return evaluate_recall_batch(delta_str, recall_gold)


def _judge_each(gold_items, judge_one) -> list[dict[str, object]]:
    """Run ``judge_one`` on every gold item and keep id, label and reasoning."""
    records: list[dict[str, object]] = []
    for item in gold_items:
        verdict = judge_one(item)
        records.append({
            "memory_id": item.memory_id,
            "label": verdict["label"],
            "reasoning": verdict["reasoning"],
        })
    return records


def _count_label(records, label: str) -> int:
    """Count the records whose ``"label"`` entry equals ``label``."""
    return sum(1 for r in records if r.get("label") == label)


def evaluate_extraction(
    session: PipelineSessionRecord,
    **_kwargs: object,
) -> dict[str, object]:
    """Score one session's memory delta: recall, correctness, update, interference.

    Only THIS session's gold points and THIS session's add/update delta are
    used, never the cumulative history.  Judge calls are issued in this order:

    1. Recall: one ``evaluate_recall_batch`` call, skipped when the session
       has no new/update gold points or the delta has no add/update entries.
    2. Correctness: one ``evaluate_correctness_batch`` call over the delta
       texts, with new + update + interference gold points as reference.
    3. Update: one ``evaluate_update_single`` call per update gold point.
    4. Interference: one ``evaluate_interference_single`` call per
       interference gold point.

    Args:
        session: The session record; ``memory_delta``, ``gold_state`` and
            ``session_id`` are read.
        **_kwargs: Accepted and ignored.

    Returns:
        A flat dict with the session id and, per stage, the score, the raw
        counts and the judge records/reasoning.  ``recall`` and
        ``update_recall`` are None when their denominator is missing or zero;
        ``update_score`` and ``interference_score`` are None when there are no
        items of that kind; ``correctness_rate`` is 0.0 when the delta has no
        add/update memories.
    """
    delta_str = _delta_to_text(session)
    delta_texts = _delta_texts(session)
    interference_total = len(session.gold_state.session_interference_memories)

    # --- Recall -----------------------------------------------------------
    recall_gold = _build_recall_gold_points(session)
    recall_result = _recall_judgement(delta_str, recall_gold)

    covered = recall_result.get("covered_count")
    upd_covered = recall_result.get("update_covered_count")
    total_gold = recall_result.get("total", len(recall_gold))
    upd_total = recall_result.get("update_total", 0)

    if recall_gold:
        recall = _safe_ratio(covered, total_gold)
        update_recall = _safe_ratio(upd_covered, upd_total)
    else:
        # Nothing to recall in this session: the metrics are not applicable.
        recall = None
        update_recall = None

    # --- Correctness ------------------------------------------------------
    correctness_gold = _build_correctness_gold_points(session)
    correctness_result = evaluate_correctness_batch(
        delta_texts, correctness_gold, interference_total
    )
    correctness_records = correctness_result.get("results", [])

    num_correct = _count_label(correctness_records, "correct")
    num_hallucination = _count_label(correctness_records, "hallucination")
    num_irrelevant = _count_label(correctness_records, "irrelevant")
    # The denominator is the number of delta memories sent to the judge, not
    # the number of records the judge happened to return.
    num_memories = len(delta_texts)
    correctness_rate = float(num_correct) / float(num_memories) if num_memories else 0.0

    # --- Update handling --------------------------------------------------
    update_records = _judge_each(
        session.gold_state.session_update_memories,
        lambda g: evaluate_update_single(
            delta_str,
            new_content=g.memory_content,
            old_contents=list(g.original_memories),
        ),
    )
    num_updated = _count_label(update_records, "updated")
    num_both = _count_label(update_records, "both")
    num_outdated = _count_label(update_records, "outdated")
    update_total_items = len(update_records)
    # "updated" earns full credit, "both" earns half, anything else earns none.
    update_score = (
        (num_updated * 1.0 + num_both * 0.5) / update_total_items
        if update_total_items else None
    )

    # --- Interference rejection -------------------------------------------
    interference_records = _judge_each(
        session.gold_state.session_interference_memories,
        lambda g: evaluate_interference_single(
            delta_str,
            interference_content=g.memory_content,
        ),
    )
    num_rejected = _count_label(interference_records, "rejected")
    num_memorized = _count_label(interference_records, "memorized")
    interference_total_items = len(interference_records)
    interference_score = (
        float(num_rejected) / interference_total_items
        if interference_total_items else None
    )

    return {
        "session_id": session.session_id,
        "recall": recall,
        "covered_count": covered,
        "num_gold": total_gold,
        "update_recall": update_recall,
        "update_covered_count": upd_covered,
        "update_total": upd_total,
        "recall_reasoning": recall_result.get("reasoning", ""),
        "correctness_rate": correctness_rate,
        "num_memories": num_memories,
        "num_correct": num_correct,
        "num_hallucination": num_hallucination,
        "num_irrelevant": num_irrelevant,
        "correctness_reasoning": correctness_result.get("reasoning", ""),
        "correctness_records": correctness_records,

        "update_score": update_score,
        "update_num_updated": num_updated,
        "update_num_both": num_both,
        "update_num_outdated": num_outdated,
        "update_total_items": update_total_items,
        "update_records": update_records,

        "interference_score": interference_score,
        "interference_num_rejected": num_rejected,
        "interference_num_memorized": num_memorized,
        "interference_total_items": interference_total_items,
        "interference_records": interference_records,
    }
|
|