# eval_framework/evaluators/extraction.py
# Source-page residue preserved as a comment: "LCZZZZ's picture" /
# "Upload eval_framework source code" / commit 85b19cf (verified).
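"""Unified session evaluation: recall, correctness, update handling and interference rejection.

Every judgement is scoped to THIS SESSION's memory delta only (add/update ops),
never the cumulative memory snapshot:

Call 1 — Recall: how many of this session's new + update gold points are
    covered by the session's memory delta? Skipped (no LLM call) when the
    session has no such gold points or when the delta is empty.
Call 2 — Correctness: is each delta memory correct, hallucinated, or irrelevant?
    (reference = this session's new, update and interference gold points)
Update calls — one call per update gold point (updated / both / outdated).
Interference calls — one call per interference gold point (rejected / memorized).

A session therefore costs at most 2 + (#update gold points) + (#interference
gold points) LLM calls.
Aggregate: per-session scores are averaged across sessions.
"""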
from __future__ import annotations
from eval_framework.judges import (
evaluate_correctness_batch,
evaluate_interference_single,
evaluate_recall_batch,
evaluate_update_single,
)
from eval_framework.pipeline.records import PipelineSessionRecord
def _delta_to_text(session: PipelineSessionRecord) -> str:
    """Render this session's add/update delta as a numbered, newline-joined list.

    Only delta entries whose ``op`` is ``"add"`` or ``"update"`` are kept (the
    full memory snapshot is never used). Numbering starts at 1 and counts kept
    entries only, so it is contiguous. Returns ``""`` when nothing qualifies.
    """
    kept_texts = (
        entry.text
        for entry in session.memory_delta
        if entry.op in ("add", "update")
    )
    return "\n".join(
        f"[{number}] {text}" for number, text in enumerate(kept_texts, start=1)
    )
def _delta_texts(session: PipelineSessionRecord) -> list[str]:
    """Return the raw text of each add/update entry in this session's delta.

    Entries keep their original order; any other ``op`` is skipped.
    """
    wanted_ops = ("add", "update")
    texts: list[str] = []
    for entry in session.memory_delta:
        if entry.op in wanted_ops:
            texts.append(entry.text)
    return texts
def _build_recall_gold_points(session: PipelineSessionRecord) -> list[str]:
    """Tag this session's new and update gold points for the recall judge.

    Only the current session's gold points are used (not cumulative history).
    New points are prefixed ``[normal]`` and come first; update points are
    prefixed ``[update]`` and follow.
    """
    gold = session.gold_state
    tagged_groups = (
        ("normal", gold.session_new_memories),
        ("update", gold.session_update_memories),
    )
    return [
        f"[{tag}] {item.memory_content}"
        for tag, items in tagged_groups
        for item in items
    ]
def _build_correctness_gold_points(session: PipelineSessionRecord) -> list[str]:
    """Tag this session's gold points as the reference for the correctness judge.

    Order is: new points (``[normal]``), then update points (``[update]``),
    then interference points (``[interference]``). Only the current session's
    gold points are included.
    """
    gold = session.gold_state
    tagged_groups = (
        ("normal", gold.session_new_memories),
        ("update", gold.session_update_memories),
        ("interference", gold.session_interference_memories),
    )
    reference: list[str] = []
    for tag, items in tagged_groups:
        reference.extend(f"[{tag}] {item.memory_content}" for item in items)
    return reference
def evaluate_extraction(
    session: PipelineSessionRecord,
    **_kwargs: object,
) -> dict[str, object]:
    """Evaluate one session's memory delta against that session's gold points.

    All judgements use only THIS session's add/update delta and THIS session's
    gold points, never the cumulative history:

    * Recall (at most 1 LLM call): fraction of the new + update gold points
      covered by the delta. No call is made when the session has no such gold
      points (``recall`` is ``None``) or when the delta is empty (``recall``
      is ``0.0``).
    * Correctness (1 LLM call): each delta memory is labelled ``correct``,
      ``hallucination`` or ``irrelevant`` against the new + update +
      interference gold points.
    * Update handling (1 call per update gold point): ``updated`` scores 1.0,
      ``both`` 0.5, ``outdated`` 0.0.
    * Interference rejection (1 call per interference gold point):
      ``rejected`` scores 1.0, ``memorized`` 0.0.

    Args:
        session: Pipeline record of the session under evaluation.
        **_kwargs: Accepted and ignored (presumably so all evaluators can share
            one call signature — confirm against the caller).

    Returns:
        A flat dict of per-session metrics and judge outputs. Ratio metrics
        (``recall``, ``update_recall``, ``update_score``,
        ``interference_score``) are ``None`` when there is nothing to score;
        ``correctness_rate`` is ``0.0`` when the delta is empty.
    """
    delta_str = _delta_to_text(session)
    delta_texts = _delta_texts(session)
    gold_state = session.gold_state

    # --- Call 1: Recall (this session's gold points vs this session's delta) ---
    recall_gold = _build_recall_gold_points(session)
    # Counted locally so the no-call branches can report it, and used as the
    # fallback if the judge's result omits "update_total".
    update_gold_total = sum(1 for p in recall_gold if p.startswith("[update]"))
    if not recall_gold:
        # Nothing to recall: ratios are undefined, not zero. Every output local
        # is bound here too (previously they were not, so the return raised
        # UnboundLocalError for such sessions).
        recall = None
        update_recall = None
        covered = upd_covered = 0
        total_gold = upd_total = 0
        recall_result: dict[str, object] = {
            "covered_count": 0, "update_covered_count": 0,
            "total": 0, "update_total": 0,
            "reasoning": "No new gold points in this session.",
        }
    elif not delta_str.strip():
        # Gold points exist but nothing was added/updated, so nothing is
        # covered. update_recall stays None when there are no update gold
        # points, matching the judged branch (0/0 is "not applicable").
        recall = 0.0
        update_recall = 0.0 if update_gold_total else None
        covered = upd_covered = 0
        total_gold = len(recall_gold)
        upd_total = update_gold_total
        recall_result = {
            "covered_count": 0, "update_covered_count": 0,
            "total": total_gold, "update_total": upd_total,
            "reasoning": "No add/update memories in this session's delta.",
        }
    else:
        recall_result = evaluate_recall_batch(delta_str, recall_gold)
        covered = recall_result.get("covered_count")
        upd_covered = recall_result.get("update_covered_count")
        total_gold = recall_result.get("total", len(recall_gold))
        upd_total = recall_result.get("update_total", update_gold_total)
        recall = (
            float(covered) / float(total_gold)
            if covered is not None and total_gold else None
        )
        update_recall = (
            float(upd_covered) / float(upd_total)
            if upd_covered is not None and upd_total else None
        )

    # --- Call 2: Correctness (this session's delta memories, reference = this session's golds) ---
    interference_total = len(gold_state.session_interference_memories)
    correctness_gold = _build_correctness_gold_points(session)
    correctness_result = evaluate_correctness_batch(delta_texts, correctness_gold, interference_total)
    correctness_records = correctness_result.get("results", [])
    num_correct = sum(1 for r in correctness_records if r.get("label") == "correct")
    num_hallucination = sum(1 for r in correctness_records if r.get("label") == "hallucination")
    num_irrelevant = sum(1 for r in correctness_records if r.get("label") == "irrelevant")
    num_memories = len(delta_texts)
    # Denominator is the number of delta memories, not the number of judged records.
    correctness_rate = float(num_correct) / float(num_memories) if num_memories else 0.0

    # --- Update handling (one LLM call per update gold point) ---
    update_records: list[dict[str, object]] = []
    for g in gold_state.session_update_memories:
        res = evaluate_update_single(
            delta_str,
            new_content=g.memory_content,
            old_contents=list(g.original_memories),
        )
        update_records.append({
            "memory_id": g.memory_id,
            "label": res["label"],
            "reasoning": res["reasoning"],
        })
    num_updated = sum(1 for r in update_records if r["label"] == "updated")
    num_both = sum(1 for r in update_records if r["label"] == "both")
    num_outdated = sum(1 for r in update_records if r["label"] == "outdated")
    update_total_items = len(update_records)
    # Score: updated=1.0, both=0.5, outdated=0.0 (other labels also count 0.0).
    update_score = (
        (num_updated * 1.0 + num_both * 0.5) / update_total_items
        if update_total_items else None
    )

    # --- Interference rejection (one LLM call per interference gold point) ---
    interference_records: list[dict[str, object]] = []
    for g in gold_state.session_interference_memories:
        res = evaluate_interference_single(
            delta_str,
            interference_content=g.memory_content,
        )
        interference_records.append({
            "memory_id": g.memory_id,
            "label": res["label"],
            "reasoning": res["reasoning"],
        })
    num_rejected = sum(1 for r in interference_records if r["label"] == "rejected")
    num_memorized = sum(1 for r in interference_records if r["label"] == "memorized")
    interference_total_items = len(interference_records)
    # Score: rejected=1.0, memorized=0.0 (other labels also count 0.0).
    interference_score = (
        float(num_rejected) / interference_total_items
        if interference_total_items else None
    )

    return {
        "session_id": session.session_id,
        "recall": recall,
        "covered_count": covered,
        "num_gold": total_gold,
        "update_recall": update_recall,
        "update_covered_count": upd_covered,
        "update_total": upd_total,
        "recall_reasoning": recall_result.get("reasoning", ""),
        "correctness_rate": correctness_rate,
        "num_memories": num_memories,
        "num_correct": num_correct,
        "num_hallucination": num_hallucination,
        "num_irrelevant": num_irrelevant,
        "correctness_reasoning": correctness_result.get("reasoning", ""),
        "correctness_records": correctness_records,
        # Update handling
        "update_score": update_score,
        "update_num_updated": num_updated,
        "update_num_both": num_both,
        "update_num_outdated": num_outdated,
        "update_total_items": update_total_items,
        "update_records": update_records,
        # Interference rejection
        "interference_score": interference_score,
        "interference_num_rejected": num_rejected,
        "interference_num_memorized": num_memorized,
        "interference_total_items": interference_total_items,
        "interference_records": interference_records,
    }