File size: 8,007 Bytes
85b19cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
"""Unified session evaluation: recall + correctness (includes update & interference).

Per session, 2 batch LLM calls — both scoped to THIS SESSION's memory delta only:
  Call 1 — Recall: how many of this session's gold points are covered by the
           session's memory delta (add/update ops)?
  Call 2 — Correctness: is each delta memory correct, hallucinated, or irrelevant?
           (reference = this session's gold points + interference)
In addition, one LLM call is made per update gold point (update handling) and
one per interference gold point (interference rejection).

Aggregate: per-session recall/correctness averaged across sessions.
"""

from __future__ import annotations

from eval_framework.judges import (
    evaluate_correctness_batch,
    evaluate_interference_single,
    evaluate_recall_batch,
    evaluate_update_single,
)
from eval_framework.pipeline.records import PipelineSessionRecord


def _delta_to_text(session: PipelineSessionRecord) -> str:
    """Render this session's added/updated memories as a numbered text block.

    Only delta entries whose ``op`` is ``"add"`` or ``"update"`` are kept.
    Each kept entry becomes one line of the form ``[k] <text>``, numbered
    from 1 in delta order; lines are joined with newlines. An empty string
    is returned when nothing qualifies.
    """
    kept_texts = (
        entry.text
        for entry in session.memory_delta
        if entry.op in ("add", "update")
    )
    return "\n".join(
        f"[{position}] {text}"
        for position, text in enumerate(kept_texts, start=1)
    )


def _delta_texts(session: PipelineSessionRecord) -> list[str]:
    """Collect the text of every memory added or updated in THIS session.

    Entries are returned in delta order; entries whose ``op`` is anything
    other than ``"add"`` or ``"update"`` are skipped.
    """
    wanted_ops = ("add", "update")
    texts: list[str] = []
    for entry in session.memory_delta:
        if entry.op in wanted_ops:
            texts.append(entry.text)
    return texts


def _build_recall_gold_points(session: PipelineSessionRecord) -> list[str]:
    """List this session's recall gold points (current session only, NOT cumulative).

    New memories come first, each prefixed with ``[normal]``, followed by
    update memories, each prefixed with ``[update]``.
    """
    state = session.gold_state
    fresh = [f"[normal] {item.memory_content}" for item in state.session_new_memories]
    revised = [f"[update] {item.memory_content}" for item in state.session_update_memories]
    return fresh + revised


def _build_correctness_gold_points(session: PipelineSessionRecord) -> list[str]:
    """List this session's reference points for the correctness judgement.

    The result contains the session's new memories (``[normal]``), then its
    update memories (``[update]``), then its interference memories
    (``[interference]``), each rendered as ``[tag] <memory_content>``.
    """
    state = session.gold_state
    tagged_groups = (
        ("normal", state.session_new_memories),
        ("update", state.session_update_memories),
        ("interference", state.session_interference_memories),
    )
    points: list[str] = []
    for tag, group in tagged_groups:
        points.extend(f"[{tag}] {item.memory_content}" for item in group)
    return points


def evaluate_extraction(
    session: PipelineSessionRecord,
    **_kwargs: object,
) -> dict[str, object]:
    """Score one session's memory delta against that session's gold state.

    Everything is scoped to THIS session: the delta's add/update memories are
    compared with this session's gold points, not the cumulative history.

    Judge calls, in the order they are made:
      1. Recall — one batch call; skipped when the session has no new/update
         gold points or when the delta contains no add/update text.
      2. Correctness — one batch call over the delta texts, with this
         session's new + update + interference gold points as reference.
      3. Update handling — one call per update gold point.
      4. Interference rejection — one call per interference gold point.

    Args:
        session: The session record providing ``session_id``,
            ``memory_delta`` and ``gold_state``.
        **_kwargs: Accepted and ignored.

    Returns:
        A flat dict of per-session metrics and judge records. ``recall`` and
        ``update_recall`` are ``None`` when there are no recall gold points,
        or when the corresponding covered count is missing or the
        corresponding total is falsy. ``correctness_rate`` is ``0.0`` when
        the delta has no add/update memories. ``update_score`` and
        ``interference_score`` are ``None`` when the session has no update /
        interference gold points respectively.
    """

    def _verdict(gold_item, judged) -> dict[str, object]:
        # Keep only the fields reported downstream for a per-item judgement.
        return {
            "memory_id": gold_item.memory_id,
            "label": judged["label"],
            "reasoning": judged["reasoning"],
        }

    gold_state = session.gold_state
    delta_str = _delta_to_text(session)
    delta_texts = _delta_texts(session)
    interference_total = len(gold_state.session_interference_memories)

    # --- Recall: this session's gold points vs this session's delta ---
    recall_gold = _build_recall_gold_points(session)
    recall_result: dict[str, object]
    if not recall_gold:
        # Nothing to recall: report zero counts without calling the judge.
        recall_result = {
            "covered_count": 0, "update_covered_count": 0,
            "total": 0, "update_total": 0,
            "reasoning": "No new gold points in this session.",
        }
    elif not delta_str.strip():
        # Gold points exist but the delta is empty: nothing can be covered.
        recall_result = {
            "covered_count": 0, "update_covered_count": 0,
            "total": len(recall_gold),
            "update_total": sum(1 for point in recall_gold if point.startswith("[update]")),
            "reasoning": "No add/update memories in this session's delta.",
        }
    else:
        recall_result = evaluate_recall_batch(delta_str, recall_gold)

    covered = recall_result.get("covered_count")
    upd_covered = recall_result.get("update_covered_count")
    total_gold = recall_result.get("total", len(recall_gold))
    upd_total = recall_result.get("update_total", 0)

    # Ratios stay None unless gold points exist and the needed counts are usable.
    recall = None
    update_recall = None
    if recall_gold:
        if covered is not None and total_gold:
            recall = float(covered) / float(total_gold)
        if upd_covered is not None and upd_total:
            update_recall = float(upd_covered) / float(upd_total)

    # --- Correctness: each delta memory judged against this session's golds ---
    correctness_gold = _build_correctness_gold_points(session)
    correctness_result = evaluate_correctness_batch(delta_texts, correctness_gold, interference_total)
    correctness_records = correctness_result.get("results", [])

    correctness_labels = [record.get("label") for record in correctness_records]
    num_correct = correctness_labels.count("correct")
    num_hallucination = correctness_labels.count("hallucination")
    num_irrelevant = correctness_labels.count("irrelevant")
    num_memories = len(delta_texts)
    if num_memories:
        correctness_rate = float(num_correct) / float(num_memories)
    else:
        correctness_rate = 0.0

    # --- Update handling: one judge call per update gold point ---
    update_records: list[dict[str, object]] = [
        _verdict(
            gold_item,
            evaluate_update_single(
                delta_str,
                new_content=gold_item.memory_content,
                old_contents=list(gold_item.original_memories),
            ),
        )
        for gold_item in gold_state.session_update_memories
    ]

    update_labels = [record["label"] for record in update_records]
    num_updated = update_labels.count("updated")
    num_both = update_labels.count("both")
    num_outdated = update_labels.count("outdated")
    update_total_items = len(update_records)
    # Score weights: updated=1.0, both=0.5, outdated=0.0
    if update_total_items:
        update_score = (num_updated * 1.0 + num_both * 0.5) / update_total_items
    else:
        update_score = None

    # --- Interference rejection: one judge call per interference gold point ---
    interference_records: list[dict[str, object]] = [
        _verdict(
            gold_item,
            evaluate_interference_single(
                delta_str,
                interference_content=gold_item.memory_content,
            ),
        )
        for gold_item in gold_state.session_interference_memories
    ]

    interference_labels = [record["label"] for record in interference_records]
    num_rejected = interference_labels.count("rejected")
    num_memorized = interference_labels.count("memorized")
    interference_total_items = len(interference_records)
    # Score weights: rejected=1.0, memorized=0.0
    if interference_total_items:
        interference_score = float(num_rejected) / interference_total_items
    else:
        interference_score = None

    return {
        "session_id": session.session_id,
        "recall": recall,
        "covered_count": covered,
        "num_gold": total_gold,
        "update_recall": update_recall,
        "update_covered_count": upd_covered,
        "update_total": upd_total,
        "recall_reasoning": recall_result.get("reasoning", ""),
        "correctness_rate": correctness_rate,
        "num_memories": num_memories,
        "num_correct": num_correct,
        "num_hallucination": num_hallucination,
        "num_irrelevant": num_irrelevant,
        "correctness_reasoning": correctness_result.get("reasoning", ""),
        "correctness_records": correctness_records,
        # Update handling
        "update_score": update_score,
        "update_num_updated": num_updated,
        "update_num_both": num_both,
        "update_num_outdated": num_outdated,
        "update_total_items": update_total_items,
        "update_records": update_records,
        # Interference rejection
        "interference_score": interference_score,
        "interference_num_rejected": num_rejected,
        "interference_num_memorized": num_memorized,
        "interference_total_items": interference_total_items,
        "interference_records": interference_records,
    }