File size: 14,794 Bytes
a493f04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31a2688
 
a493f04
31a2688
a493f04
31a2688
 
a493f04
31a2688
a493f04
 
 
 
 
 
31a2688
 
 
a493f04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31a2688
 
a493f04
31a2688
a493f04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31a2688
 
 
a493f04
 
 
 
 
31a2688
 
a493f04
31a2688
 
a493f04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31a2688
 
 
 
a493f04
31a2688
a493f04
 
31a2688
 
 
 
 
a493f04
 
31a2688
a493f04
 
31a2688
 
 
 
a493f04
31a2688
a493f04
31a2688
 
a493f04
31a2688
 
 
 
 
 
a493f04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31a2688
 
 
 
 
a493f04
 
 
 
 
 
31a2688
 
a493f04
 
 
 
 
31a2688
 
a493f04
 
 
31a2688
a493f04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31a2688
a493f04
 
 
 
 
 
 
31a2688
a493f04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31a2688
 
 
 
 
a493f04
 
 
 
 
31a2688
 
a493f04
 
 
31a2688
 
a493f04
 
31a2688
 
a493f04
 
 
 
 
 
31a2688
a493f04
 
31a2688
 
 
 
a493f04
 
 
31a2688
a493f04
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
"""RAGAS-based evaluation for retrieval and generation quality.

Uses the legacy ``ragas.metrics`` classes (``LLMContextPrecisionWithReference``,
``LLMContextRecall``, ``Faithfulness``, ``AnswerRelevancy``,
``AnswerCorrectness``, ``FactualCorrectness``) rather than the newer
``ragas.metrics.collections`` API. The collections API only accepts
``InstructorLLM`` instances and would force us to import provider-specific
clients (openai / anthropic / ...) directly, which violates the project's
``provider.py``-only rule.

The legacy classes are wired with a ``LangchainLLMWrapper`` and a
``LangchainEmbeddingsWrapper`` at ``evaluate()`` time, so we keep using the
LangChain abstractions returned by ``src/provider.py`` everywhere.

Two metric families, two evaluation passes
------------------------------------------

For a multilingual test set (English questions querying Danish documents,
English-language answers, English+Danish reference fields), different metrics
want different reference languages:

- **Grounding metrics** (``ContextPrecision``, ``ContextRecall``) compare the
  reference against retrieved Danish chunks. They work best with a Danish
  reference (mono-lingual matching), so we feed them ``source_quote_da``.
  ``Faithfulness`` and ``AnswerRelevancy`` ignore the reference field but ride
  along in this pass to share the dataset.

- **Correctness metrics** (``AnswerCorrectness``, ``FactualCorrectness``)
  compare the generated English answer against the reference directly. They
  work best with an English reference, so we feed them ``reference_en``.

``evaluate()`` builds two ``EvaluationDataset`` instances (same questions,
contexts and answers, different ``reference`` field per pass) and runs RAGAS
twice. The aggregate scores are unioned and per-sample rows are merged on
``user_input``.

Why both families
-----------------

``Faithfulness`` measures whether each claim in the answer is supported by the
retrieved chunks. It does **not** check whether those chunks (and therefore
the answer) are actually correct. An answer that confidently quotes the wrong
chunks scores 1.0; an answer that hedges with prior-knowledge inference but
matches the ground truth scores low. Adding ``AnswerCorrectness`` /
``FactualCorrectness`` (which compare directly to ``reference_en``) closes
this gap by measuring **whether the answer is right**, not just whether it
is grounded in whatever was retrieved.

Note: ``AnswerRelevancy`` is constructed with ``strictness=1``. The default of
3 issues an OpenAI-style ``n=3`` request to generate three hypothetical
questions in one API call, which Groq's API rejects with HTTP 400
``'n' : number must be at most 1``. With ``strictness=1`` the metric makes a
single call per sample, which all providers support.
"""

import logging
from typing import Any

from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from ragas import EvaluationDataset, SingleTurnSample, evaluate
from ragas.embeddings import LangchainEmbeddingsWrapper
from ragas.llms import LangchainLLMWrapper
from ragas.metrics._answer_correctness import AnswerCorrectness
from ragas.metrics._answer_relevance import AnswerRelevancy
from ragas.metrics._context_precision import LLMContextPrecisionWithReference
from ragas.metrics._context_recall import LLMContextRecall
from ragas.metrics._factual_correctness import FactualCorrectness
from ragas.metrics._faithfulness import Faithfulness

logger = logging.getLogger(__name__)

# Columns produced by RAGAS dataframes that are NOT metric scores. Used to
# separate metric columns from sample fields when computing aggregates.
_NON_METRIC_COLS: frozenset[str] = frozenset(
    {
        "user_input",
        "retrieved_contexts",
        "reference",
        "response",
        "ground_truth",
        "question",
        "answer",
        "contexts",
    }
)

# Reference-key fallback chains used by ``_resolve_reference``. Each chain
# starts with the preferred key for a particular metric family, then falls
# back to whichever other key is available so legacy plain-string ground
# truths still work.
_GROUNDING_REF_CHAIN: tuple[str, ...] = ("source_quote_da", "reference_en", "reference")
_CORRECTNESS_REF_CHAIN: tuple[str, ...] = ("reference_en", "source_quote_da", "reference")


class RAGEvaluator:
    """Evaluates RAG pipeline quality using two complementary RAGAS metric families.

    The judge LLM is independent from the generation LLM. This is critical
    when generation runs on a small local model: a stronger judge gives
    substantially less noisy scores.

    Each ground-truth entry may be either a plain string or a dict. The dict
    form is used by the multilingual test set::

        {
            "reference_en":    "English reference answer (informational)",
            "source_quote_da": "Verbatim Danish quote from the source document"
        }

    See the module docstring for why two metric families and two evaluation
    passes are used.
    """

    def __init__(self, llm: BaseChatModel, embeddings: Embeddings) -> None:
        """Initialize the evaluator.

        Args:
            llm: A LangChain BaseChatModel instance to use as the RAGAS judge.
                Should be a strong model (>= ~30B params) for reliable scoring.
            embeddings: A LangChain Embeddings instance. Required because
                ``AnswerRelevancy`` and ``AnswerCorrectness`` compute cosine
                similarity between text pairs.
        """
        self._llm = LangchainLLMWrapper(llm)
        self._embeddings = LangchainEmbeddingsWrapper(embeddings)
        logger.info("RAGEvaluator initialized")

    @staticmethod
    def _resolve_reference(
        ground_truth: str | dict[str, Any],
        ref_chain: tuple[str, ...],
    ) -> str:
        """Pick the best reference string for a given metric family.

        Args:
            ground_truth: Either a plain reference string or a dict with
                ``source_quote_da`` / ``reference_en`` / ``reference`` keys.
            ref_chain: Ordered tuple of dict keys to try, most preferred first.

        Returns:
            The reference string to feed into RAGAS.
        """
        if isinstance(ground_truth, str):
            return ground_truth
        if isinstance(ground_truth, dict):
            for key in ref_chain:
                value = ground_truth.get(key)
                if isinstance(value, str) and value.strip():
                    return value
        return str(ground_truth)

    def _build_dataset(
        self,
        questions: list[str],
        contexts: list[list[str]],
        ground_truths: list[str | dict[str, Any]],
        answers: list[str] | None = None,
        *,
        ref_chain: tuple[str, ...] = _GROUNDING_REF_CHAIN,
    ) -> EvaluationDataset:
        """Build a RAGAS EvaluationDataset from raw inputs.

        Args:
            questions: List of input questions.
            contexts: Retrieved context lists per question.
            ground_truths: Reference answers (str or dict).
            answers: Optional list of generated answers.
            ref_chain: Reference-key fallback chain to use when resolving
                each ground-truth entry.

        Returns:
            EvaluationDataset ready for evaluation.
        """
        samples: list[SingleTurnSample] = []
        for i, question in enumerate(questions):
            sample_kwargs: dict[str, Any] = {
                "user_input": question,
                "retrieved_contexts": contexts[i],
                "reference": self._resolve_reference(ground_truths[i], ref_chain),
            }
            if answers is not None:
                sample_kwargs["response"] = answers[i]
            samples.append(SingleTurnSample(**sample_kwargs))
        return EvaluationDataset(samples=samples)

    @staticmethod
    def _result_to_dicts(
        result: Any,
    ) -> tuple[dict[str, float], list[dict[str, Any]]]:
        """Convert a RAGAS EvaluationResult into aggregate scores + per-sample rows.

        Uses the public ``to_pandas()`` API instead of the private
        ``_repr_dict`` so the code does not break across RAGAS minor versions.

        Args:
            result: RAGAS EvaluationResult instance.

        Returns:
            Tuple of (aggregate metric → mean score, list of per-sample row dicts).
        """
        df = result.to_pandas()
        metric_cols = [c for c in df.columns if c not in _NON_METRIC_COLS]
        aggregate: dict[str, float] = {}
        for col in metric_cols:
            if df[col].dtype.kind in "fi":
                aggregate[col] = float(df[col].mean())
        per_sample: list[dict[str, Any]] = df.to_dict(orient="records")
        return aggregate, per_sample

    @staticmethod
    def _merge_per_sample(
        rows_a: list[dict[str, Any]],
        rows_b: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Merge two per-sample row lists by ``user_input``.

        Both lists are produced by ``_result_to_dicts`` and contain the same
        questions but different metric columns. The reference field will
        differ between the two passes (Danish vs English), so we keep the
        Danish one (from rows_a) as the canonical ``reference`` and add an
        ``reference_en`` column from rows_b for transparency.

        Args:
            rows_a: Rows from the grounding pass (reference = Danish quote).
            rows_b: Rows from the correctness pass (reference = English answer).

        Returns:
            Merged list of row dicts with all metric columns from both passes.
        """
        rows_b_by_q: dict[str, dict[str, Any]] = {row["user_input"]: row for row in rows_b}
        merged: list[dict[str, Any]] = []
        for row_a in rows_a:
            row_b = rows_b_by_q.get(row_a["user_input"], {})
            merged_row: dict[str, Any] = dict(row_a)
            # Preserve the English reference for transparency.
            if "reference" in row_b:
                merged_row["reference_en"] = row_b["reference"]
            # Add metric columns from pass B that are not already present.
            for key, value in row_b.items():
                if key in _NON_METRIC_COLS:
                    continue
                if key not in merged_row:
                    merged_row[key] = value
            merged.append(merged_row)
        return merged

    def evaluate(
        self,
        questions: list[str],
        answers: list[str],
        contexts: list[list[str]],
        ground_truths: list[str | dict[str, Any]],
    ) -> dict[str, Any]:
        """Run full RAGAS evaluation across grounding + correctness metric families.

        Performs two passes against the same questions / answers / contexts
        but with different reference languages (see module docstring).

        Args:
            questions: Input questions.
            answers: Generated answers from the RAG pipeline.
            contexts: Retrieved context lists per question.
            ground_truths: Reference answers (str or dict per
                ``_resolve_reference``).

        Returns:
            Dict with two keys:
                ``aggregate``:  metric name → mean score across all samples
                ``per_sample``: list of per-sample dicts (one row per question)
        """
        n = len(questions)
        logger.info("Running full evaluation on %d samples (two passes)", n)

        # ---- Pass 1: grounding metrics with Danish reference ----
        dataset_da = self._build_dataset(
            questions,
            contexts,
            ground_truths,
            answers,
            ref_chain=_GROUNDING_REF_CHAIN,
        )
        grounding_metrics = [
            Faithfulness(),
            AnswerRelevancy(strictness=1),
            LLMContextPrecisionWithReference(),
            LLMContextRecall(),
        ]
        logger.info("Pass 1/2: grounding metrics (reference = Danish quote)")
        result_a = evaluate(
            dataset=dataset_da,
            metrics=grounding_metrics,
            llm=self._llm,
            embeddings=self._embeddings,
            show_progress=False,
        )
        agg_a, samples_a = self._result_to_dicts(result_a)
        logger.info("Pass 1/2 aggregate: %s", agg_a)

        # ---- Pass 2: correctness metrics with English reference ----
        dataset_en = self._build_dataset(
            questions,
            contexts,
            ground_truths,
            answers,
            ref_chain=_CORRECTNESS_REF_CHAIN,
        )
        correctness_metrics = [
            AnswerCorrectness(),
            FactualCorrectness(),
        ]
        logger.info("Pass 2/2: correctness metrics (reference = English answer)")
        result_b = evaluate(
            dataset=dataset_en,
            metrics=correctness_metrics,
            llm=self._llm,
            embeddings=self._embeddings,
            show_progress=False,
        )
        agg_b, samples_b = self._result_to_dicts(result_b)
        logger.info("Pass 2/2 aggregate: %s", agg_b)

        # ---- Merge ----
        aggregate: dict[str, float] = {**agg_a, **agg_b}
        per_sample = self._merge_per_sample(samples_a, samples_b)
        logger.info("Combined aggregate: %s", aggregate)
        return {"aggregate": aggregate, "per_sample": per_sample}

    def evaluate_retrieval(
        self,
        questions: list[str],
        contexts: list[list[str]],
        ground_truths: list[str | dict[str, Any]],
    ) -> dict[str, Any]:
        """Evaluate retrieval quality only (ContextPrecision + ContextRecall).

        Single-pass against the Danish reference, no correctness metrics.

        Args:
            questions: Input questions.
            contexts: Retrieved context lists per question.
            ground_truths: Reference answers (str or dict).

        Returns:
            Dict with ``aggregate`` and ``per_sample`` keys, same shape as
            ``evaluate()`` but only retrieval metrics.
        """
        logger.info("Running retrieval evaluation on %d samples", len(questions))
        dataset = self._build_dataset(
            questions,
            contexts,
            ground_truths,
            ref_chain=_GROUNDING_REF_CHAIN,
        )
        metrics = [
            LLMContextPrecisionWithReference(),
            LLMContextRecall(),
        ]
        result = evaluate(
            dataset=dataset,
            metrics=metrics,
            llm=self._llm,
            embeddings=self._embeddings,
            show_progress=False,
        )
        aggregate, per_sample = self._result_to_dicts(result)
        logger.info("Retrieval evaluation aggregate: %s", aggregate)
        return {"aggregate": aggregate, "per_sample": per_sample}