| """Shared checkpoint QA: retrieval via adapter + answer from an injected callable. |
| |
| ``AnswerFn`` may return either a plain ``str`` (legacy) or a |
| ``(str, list[str])`` tuple of ``(answer, cited_memories)``. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from collections.abc import Callable |
| from typing import Union |
|
|
| from eval_framework.datasets.domain_a_v2 import NormalizedCheckpoint, NormalizedCheckpointQuestion |
| from eval_framework.datasets.schemas import RetrievalRecord |
| from eval_framework.memory_adapters.base import MemoryAdapter |
| from eval_framework.pipeline.records import PipelineCheckpointQARecord |
|
|
| |
# Tuple form of an answer: ``(answer, cited_memories)``.
_CitedAnswer = tuple[str, list[str]]

# What an answer callable may hand back: a bare answer string (legacy form)
# or an answer paired with the memories it cites.
AnswerResult = Union[str, _CitedAnswer]

# Injected callable that answers one checkpoint question given its retrieval.
AnswerFn = Callable[
    [NormalizedCheckpointQuestion, RetrievalRecord],
    AnswerResult,
]
|
|
|
|
def run_checkpoint_qa_records(
    adapter: MemoryAdapter,
    *,
    sample_id: str,
    sample_uuid: str,
    checkpoint: NormalizedCheckpoint,
    top_k: int,
    answer_fn: AnswerFn,
) -> tuple[PipelineCheckpointQARecord, ...]:
    """Build one QA record per question in ``checkpoint.questions``.

    Each question is first passed to ``adapter.retrieve`` (with ``top_k``),
    and the question plus its retrieval are then handed to ``answer_fn``.
    ``adapter.answer`` is never called here.

    Args:
        adapter: Memory adapter; only its ``retrieve`` method is used.
        sample_id: Copied verbatim onto every record.
        sample_uuid: Copied verbatim onto every record.
        checkpoint: Supplies ``checkpoint_id`` and the questions to process.
        top_k: Forwarded as the second argument of ``adapter.retrieve``.
        answer_fn: Callable returning either the answer alone or an
            ``(answer, cited_memories)`` tuple.

    Returns:
        A tuple of records in the same order as ``checkpoint.questions``.
        ``cited_memories`` is an empty tuple when ``answer_fn`` returned a
        non-tuple value.
    """

    def _record_for(question: NormalizedCheckpointQuestion) -> PipelineCheckpointQARecord:
        # Retrieval always happens before answering, once per question.
        retrieved = adapter.retrieve(question.question, top_k)
        answer = answer_fn(question, retrieved)

        # Anything that is not a tuple is treated as the bare answer with no
        # citations; a tuple is unpacked as (answer, cited_memories).
        if not isinstance(answer, tuple):
            text, citations = answer, []
        else:
            text, citations = answer

        return PipelineCheckpointQARecord(
            sample_id=sample_id,
            sample_uuid=sample_uuid,
            checkpoint_id=checkpoint.checkpoint_id,
            question=question.question,
            gold_answer=question.gold_answer,
            gold_evidence_memory_ids=question.gold_evidence_memory_ids,
            gold_evidence_contents=question.gold_evidence_contents,
            question_type=question.question_type,
            question_type_abbrev=question.question_type_abbrev,
            difficulty=question.difficulty,
            retrieval=retrieved,
            generated_answer=text,
            cited_memories=tuple(citations),
        )

    return tuple(_record_for(question) for question in checkpoint.questions)
|
|