File size: 8,007 Bytes
85b19cf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 | """Unified session evaluation: recall + correctness (includes update & interference).
Per session, every LLM call is scoped to THIS SESSION's memory delta only:
Call 1 — Recall: how many of this session's gold points are covered by the
session's memory delta (add/update ops)?
Call 2 — Correctness: is each delta memory correct, hallucinated, or irrelevant?
(reference = this session's gold points + interference)
Calls 3+ — Update handling: one call per update gold point (updated / both / outdated).
Calls 4+ — Interference rejection: one call per interference gold point (rejected / memorized).
Aggregate: per-session recall/correctness averaged across sessions.
"""
from __future__ import annotations
from eval_framework.judges import (
evaluate_correctness_batch,
evaluate_interference_single,
evaluate_recall_batch,
evaluate_update_single,
)
from eval_framework.pipeline.records import PipelineSessionRecord
def _delta_to_text(session: PipelineSessionRecord) -> str:
    """Render this session's add/update memories as a numbered list.

    Only delta entries whose op is "add" or "update" are kept (not the full
    snapshot). Numbering starts at 1 and counts kept entries only. Each line
    looks like ``[n] <text>``; the result is "" when nothing qualifies.
    """
    kept_texts = (
        entry.text
        for entry in session.memory_delta
        if entry.op in ("add", "update")
    )
    return "\n".join(
        f"[{position}] {text}"
        for position, text in enumerate(kept_texts, start=1)
    )
def _delta_texts(session: PipelineSessionRecord) -> list[str]:
    """Return the raw texts of this session's add/update delta entries, in order."""
    texts: list[str] = []
    for entry in session.memory_delta:
        if entry.op not in ("add", "update"):
            continue  # other ops contribute nothing to the evaluated delta
        texts.append(entry.text)
    return texts
def _build_recall_gold_points(session: PipelineSessionRecord) -> list[str]:
    """Tag this session's own new and update gold points (NOT cumulative).

    New memories come first as ``[normal] <content>``, followed by update
    memories as ``[update] <content>``.
    """
    gold = session.gold_state
    sources = (
        ("normal", gold.session_new_memories),
        ("update", gold.session_update_memories),
    )
    points: list[str] = []
    for tag, memories in sources:
        points.extend(f"[{tag}] {item.memory_content}" for item in memories)
    return points
def _build_correctness_gold_points(session: PipelineSessionRecord) -> list[str]:
    """Tag this session's new, update and interference gold points as reference.

    Order is normal, then update, then interference; each entry looks like
    ``[<tag>] <content>``.
    """
    gold = session.gold_state
    sources = (
        ("normal", gold.session_new_memories),
        ("update", gold.session_update_memories),
        ("interference", gold.session_interference_memories),
    )
    points: list[str] = []
    for tag, memories in sources:
        points.extend(f"[{tag}] {item.memory_content}" for item in memories)
    return points
def _recall_metrics(
    session: PipelineSessionRecord, delta_str: str
) -> dict[str, object]:
    """Call 1: score how many of this session's gold points the delta covers."""
    recall_gold = _build_recall_gold_points(session)
    update_gold_total = sum(1 for p in recall_gold if p.startswith("[update]"))
    if not recall_gold:
        # Nothing to recall: no LLM call; both ratios come out as None below.
        recall_result: dict[str, object] = {
            "covered_count": 0, "update_covered_count": 0,
            "total": 0, "update_total": 0,
            "reasoning": "No new gold points in this session.",
        }
    elif not delta_str.strip():
        # Gold points exist but nothing was stored: no LLM call, nothing covered.
        recall_result = {
            "covered_count": 0, "update_covered_count": 0,
            "total": len(recall_gold), "update_total": update_gold_total,
            "reasoning": "No add/update memories in this session's delta.",
        }
    else:
        recall_result = evaluate_recall_batch(delta_str, recall_gold)

    # Read the counts back the same way for every branch so every returned
    # key is always bound, whichever path was taken.
    covered = recall_result.get("covered_count")
    upd_covered = recall_result.get("update_covered_count")
    total_gold = recall_result.get("total", len(recall_gold))
    upd_total = recall_result.get("update_total", update_gold_total)
    # A ratio is None when its denominator is 0 (nothing to recall) or the
    # judge did not report a covered count.
    recall = (
        float(covered) / float(total_gold)
        if covered is not None and total_gold else None
    )
    update_recall = (
        float(upd_covered) / float(upd_total)
        if upd_covered is not None and upd_total else None
    )
    return {
        "recall": recall,
        "covered_count": covered,
        "num_gold": total_gold,
        "update_recall": update_recall,
        "update_covered_count": upd_covered,
        "update_total": upd_total,
        "recall_reasoning": recall_result.get("reasoning", ""),
    }


def _correctness_metrics(
    session: PipelineSessionRecord, delta_texts: list[str]
) -> dict[str, object]:
    """Call 2: label each delta memory as correct / hallucination / irrelevant."""
    if delta_texts:
        correctness_result = evaluate_correctness_batch(
            delta_texts,
            _build_correctness_gold_points(session),
            len(session.gold_state.session_interference_memories),
        )
    else:
        # No memories to judge: skip the LLM call so it cannot produce
        # records that would be counted against zero memories.
        correctness_result = {
            "results": [],
            "reasoning": "No add/update memories in this session's delta.",
        }
    records = correctness_result.get("results") or []
    num_memories = len(delta_texts)
    num_correct = sum(1 for r in records if r.get("label") == "correct")
    num_hallucination = sum(1 for r in records if r.get("label") == "hallucination")
    num_irrelevant = sum(1 for r in records if r.get("label") == "irrelevant")
    correctness_rate = (
        float(num_correct) / float(num_memories) if num_memories else 0.0
    )
    return {
        "correctness_rate": correctness_rate,
        "num_memories": num_memories,
        "num_correct": num_correct,
        "num_hallucination": num_hallucination,
        "num_irrelevant": num_irrelevant,
        "correctness_reasoning": correctness_result.get("reasoning", ""),
        "correctness_records": records,
    }


def _update_metrics(
    session: PipelineSessionRecord, delta_str: str
) -> dict[str, object]:
    """Calls 3+: one judge call per update gold point; updated=1, both=0.5, outdated=0."""
    records: list[dict[str, object]] = []
    for g in session.gold_state.session_update_memories:
        res = evaluate_update_single(
            delta_str,
            new_content=g.memory_content,
            old_contents=list(g.original_memories),
        )
        records.append({
            "memory_id": g.memory_id,
            "label": res["label"],
            "reasoning": res["reasoning"],
        })
    num_updated = sum(1 for r in records if r["label"] == "updated")
    num_both = sum(1 for r in records if r["label"] == "both")
    num_outdated = sum(1 for r in records if r["label"] == "outdated")
    total_items = len(records)
    score = (
        (num_updated * 1.0 + num_both * 0.5) / total_items
        if total_items else None
    )
    return {
        "update_score": score,
        "update_num_updated": num_updated,
        "update_num_both": num_both,
        "update_num_outdated": num_outdated,
        "update_total_items": total_items,
        "update_records": records,
    }


def _interference_metrics(
    session: PipelineSessionRecord, delta_str: str
) -> dict[str, object]:
    """Calls 4+: one judge call per interference gold point; rejected=1, memorized=0."""
    records: list[dict[str, object]] = []
    for g in session.gold_state.session_interference_memories:
        res = evaluate_interference_single(
            delta_str,
            interference_content=g.memory_content,
        )
        records.append({
            "memory_id": g.memory_id,
            "label": res["label"],
            "reasoning": res["reasoning"],
        })
    num_rejected = sum(1 for r in records if r["label"] == "rejected")
    num_memorized = sum(1 for r in records if r["label"] == "memorized")
    total_items = len(records)
    score = float(num_rejected) / total_items if total_items else None
    return {
        "interference_score": score,
        "interference_num_rejected": num_rejected,
        "interference_num_memorized": num_memorized,
        "interference_total_items": total_items,
        "interference_records": records,
    }


def evaluate_extraction(
    session: PipelineSessionRecord,
    **_kwargs: object,
) -> dict[str, object]:
    """Unified session evaluation: recall, correctness, update handling, interference.

    Uses only THIS session's memory delta (add/update ops) and this session's
    gold points, not the cumulative history. Aggregate averages per-session
    scores.

    LLM calls, in order:
      1. recall batch (skipped when there are no gold points or no delta);
      2. correctness batch (skipped when there are no delta memories);
      3. one update judgement per update gold point;
      4. one interference judgement per interference gold point.

    Args:
        session: the pipeline record for a single session.
        **_kwargs: accepted and ignored (presumably so all evaluators share a
            call signature — confirm with the caller).

    Returns:
        A flat dict keyed by ``session_id`` plus the recall, correctness,
        update and interference metrics. Ratios are None when there is nothing
        to score (e.g. no gold points of that kind), except
        ``correctness_rate``, which is 0.0 when there are no delta memories.
    """
    delta_str = _delta_to_text(session)
    delta_texts = _delta_texts(session)
    # Dict displays evaluate left to right, so the judge calls run in the
    # order documented above and the keys keep a stable order.
    return {
        "session_id": session.session_id,
        **_recall_metrics(session, delta_str),
        **_correctness_metrics(session, delta_texts),
        **_update_metrics(session, delta_str),
        **_interference_metrics(session, delta_str),
    }
|