File size: 3,734 Bytes
65dfa4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""Evaluation runner — runs the RAG pipeline on an eval dataset and scores it.

Produces per-example and aggregate metrics for both retrieval and generation.
"""

import logging
from dataclasses import dataclass, field

from src.evaluation.dataset import EvalDataset, EvalExample
from src.evaluation.metrics import evaluate_generation, evaluate_retrieval
from src.generation.rag_engine import RAGEngine

logger = logging.getLogger(__name__)


@dataclass
class ExampleResult:
    """Evaluation results for a single example.

    Instances are built by ``EvalRunner.evaluate_example`` from one dataset
    example and the engine response to its query.
    """

    # Query text copied from the evaluated example.
    query: str
    # Metrics dict returned by ``evaluate_retrieval`` for this example.
    retrieval_metrics: dict
    # Metrics dict returned by ``evaluate_generation`` for this example.
    generation_metrics: dict
    # ``paper_id`` of every source in the engine response, in response order.
    retrieved_paper_ids: list[str]
    # Answer text taken from the engine response.
    answer: str


@dataclass
class EvalReport:
    """Aggregate evaluation report.

    Built by ``EvalRunner.run``. The two aggregate dicts default to empty
    and stay empty when no examples were evaluated.
    """

    # ``name`` of the dataset that was evaluated.
    dataset_name: str
    # Number of entries in ``example_results``.
    num_examples: int
    # Per-example results, in dataset iteration order.
    example_results: list[ExampleResult]
    # Retrieval metric name -> value aggregated across the examples.
    aggregate_retrieval: dict = field(default_factory=dict)
    # Generation metric name -> value aggregated across the examples.
    aggregate_generation: dict = field(default_factory=dict)


class EvalRunner:
    """Runs evaluation on a dataset using the RAG engine.

    Every example is sent through ``rag_engine.query``; the response is scored
    with ``evaluate_retrieval`` and ``evaluate_generation``, and ``run``
    averages the per-example metrics into an ``EvalReport``.
    """

    def __init__(self, rag_engine: RAGEngine, top_k: int = 5):
        """Store the engine and the retrieval depth.

        Args:
            rag_engine: Engine whose ``query`` method is called once per example.
            top_k: Forwarded as ``top_k`` to every ``query`` call.
        """
        self.engine = rag_engine
        self.top_k = top_k

    def evaluate_example(self, example: EvalExample) -> ExampleResult:
        """Query the engine for one example and score the response.

        Args:
            example: Example providing ``query``, ``relevant_paper_ids``,
                ``reference_answer`` and ``keywords``.

        Returns:
            An ``ExampleResult`` holding the retrieval and generation metrics,
            the retrieved paper ids and the generated answer.
        """
        response = self.engine.query(question=example.query, top_k=self.top_k)

        # Assumes each source is a mapping with a "paper_id" key — a source
        # without it raises KeyError here.
        retrieved_paper_ids = [source["paper_id"] for source in response.sources]

        retrieval_metrics = evaluate_retrieval(
            retrieved_ids=retrieved_paper_ids,
            relevant_ids=set(example.relevant_paper_ids),
        )

        generation_metrics = evaluate_generation(
            prediction=response.answer,
            reference=example.reference_answer,
            # Empty/absent keywords are normalised to None.
            keywords=example.keywords or None,
        )

        return ExampleResult(
            query=example.query,
            retrieval_metrics=retrieval_metrics,
            generation_metrics=generation_metrics,
            retrieved_paper_ids=retrieved_paper_ids,
            answer=response.answer,
        )

    @staticmethod
    def _mean_per_metric(metric_dicts: list[dict]) -> dict[str, float]:
        """Average each metric over the dicts that actually contain it.

        A metric missing from some dicts is averaged only over the dicts that
        report it, so "not applicable" is never counted as a score of zero.
        Keys appear in first-seen order; an empty input yields an empty dict.
        """
        totals: dict[str, float] = {}
        counts: dict[str, int] = {}
        for metrics in metric_dicts:
            for key, value in metrics.items():
                totals[key] = totals.get(key, 0.0) + value
                counts[key] = counts.get(key, 0) + 1
        return {key: total / counts[key] for key, total in totals.items()}

    def run(self, dataset: EvalDataset) -> EvalReport:
        """Run evaluation on the full dataset.

        Args:
            dataset: Iterable, sized collection of examples with a ``name``.

        Returns:
            An ``EvalReport`` with one ``ExampleResult`` per example (in
            iteration order) plus per-metric means. Each mean is taken over
            the examples that report that metric; both aggregate dicts are
            empty when the dataset yields no examples.
        """
        total = len(dataset)
        logger.info("Running evaluation on %r (%d examples)", dataset.name, total)

        results: list[ExampleResult] = []
        for position, example in enumerate(dataset, start=1):
            logger.info("Evaluating %d/%d: %r", position, total, example.query)
            results.append(self.evaluate_example(example))

        aggregate_retrieval = self._mean_per_metric(
            [result.retrieval_metrics for result in results]
        )
        # Generation metric keys can differ between examples (e.g. keyword
        # metrics only exist when keywords were supplied); sort for a stable
        # key order in the report.
        aggregate_generation = dict(
            sorted(
                self._mean_per_metric(
                    [result.generation_metrics for result in results]
                ).items()
            )
        )

        report = EvalReport(
            dataset_name=dataset.name,
            num_examples=len(results),
            example_results=results,
            aggregate_retrieval=aggregate_retrieval,
            aggregate_generation=aggregate_generation,
        )

        logger.info(
            "Evaluation complete — retrieval: %s, generation: %s",
            {k: round(v, 3) for k, v in report.aggregate_retrieval.items()},
            {k: round(v, 3) for k, v in report.aggregate_generation.items()},
        )
        return report