"""Checkpoint QA evaluation: answer quality + batch evidence coverage.

Two dimensions:
1. Answer evaluation: Correct / Hallucination / Omission (1 LLM call)
2. Evidence coverage: how many gold evidence points are covered by the
   memories the model actually *cited* when answering? (1 LLM call)
"""

from __future__ import annotations

from eval_framework.judges import evaluate_evidence_batch, evaluate_qa_llm
from eval_framework.pipeline.records import PipelineCheckpointQARecord


def evaluate_checkpoint_qa(
    record: PipelineCheckpointQARecord,
    **_kwargs: object,
) -> dict[str, object]:
    """LLM-judged QA evaluation: answer correctness + evidence coverage."""

    # --- Build cited-memories text (what the model actually used) ---
    if record.cited_memories:
        cited_lines = [f"[{i + 1}] {m}" for i, m in enumerate(record.cited_memories)]
        cited_str = "\n".join(cited_lines)
    else:
        # Fallback: use full retrieval (legacy records without cited_memories)
        cited_lines = [f"[{item.rank}] {item.text}" for item in record.retrieval.items]
        cited_str = "\n".join(cited_lines) if cited_lines else ""

    # --- Answer evaluation (1 LLM call, unchanged) ---
    gold_evidence_str = (
        "\n".join(record.gold_evidence_contents)
        if record.gold_evidence_contents
        else "No evidence available."
    )
    answer_result = evaluate_qa_llm(
        question=record.question,
        reference_answer=record.gold_answer,
        key_memory_points=gold_evidence_str,
        system_response=record.generated_answer,
    )
    answer_label = answer_result.get("evaluation_result")

    # --- Evidence coverage (1 LLM call, batch) ---
    # Only check against cited memories, not the full retrieval
    gold_contents = list(record.gold_evidence_contents)
    evidence_result: dict[str, object] = {
        "covered_count": 0,
        "total": len(gold_contents),
        "reasoning": "",
    }

    if gold_contents and cited_str.strip():
        evidence_result = evaluate_evidence_batch(cited_str, gold_contents)

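    # Hit rate = fraction of gold evidence points the judge marks as covered
    # by the cited memories (e.g. 2 of 3 covered -> 2/3, about 0.67).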
    covered = evidence_result.get("covered_count")
    total_ev = evidence_result.get("total", len(gold_contents))
    if covered is not None and total_ev:
        evidence_hit_rate = float(covered) / float(total_ev)
    else:
        evidence_hit_rate = 0.0

    return {
        "answer_label": answer_label,
        "answer_reasoning": answer_result.get("reasoning", ""),
        "answer_is_valid": answer_label in ("Correct", "Hallucination", "Omission"),
        "evidence_hit_rate": evidence_hit_rate,
        "evidence_covered_count": covered,
        "num_evidence": total_ev,
        "evidence_reasoning": evidence_result.get("reasoning", ""),
        "num_cited_memories": len(record.cited_memories),
    }
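

# Illustrative usage sketch (assumptions: the exact constructor signature of
# PipelineCheckpointQARecord is not shown in this module, so the keyword
# fields below are inferred from the attributes read above, and both judge
# helpers issue real LLM requests when called):
#
#     record = PipelineCheckpointQARecord(
#         question="Where did Alice move in 2021?",
#         gold_answer="To Berlin.",
#         gold_evidence_contents=["Alice moved to Berlin in March 2021."],
#         generated_answer="She moved to Berlin.",
#         cited_memories=["Alice moved to Berlin in March 2021."],
#     )
#     result = evaluate_checkpoint_qa(record)
#
# The returned dict has "answer_label" in {"Correct", "Hallucination",
# "Omission"} (or None if the judge response is malformed) and
# "evidence_hit_rate" = evidence_covered_count / num_evidence in [0.0, 1.0].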