"""Shared checkpoint QA: retrieval via adapter + answer from an injected callable. ``AnswerFn`` may return either a plain ``str`` (legacy) or a ``(str, list[str])`` tuple of ``(answer, cited_memories)``. """ from __future__ import annotations from collections.abc import Callable from typing import Union from eval_framework.datasets.domain_a_v2 import NormalizedCheckpoint, NormalizedCheckpointQuestion from eval_framework.datasets.schemas import RetrievalRecord from eval_framework.memory_adapters.base import MemoryAdapter from eval_framework.pipeline.records import PipelineCheckpointQARecord # answer_fn may return str (legacy) or (str, list[str]) AnswerResult = Union[str, tuple[str, list[str]]] AnswerFn = Callable[[NormalizedCheckpointQuestion, RetrievalRecord], AnswerResult] def run_checkpoint_qa_records( adapter: MemoryAdapter, *, sample_id: str, sample_uuid: str, checkpoint: NormalizedCheckpoint, top_k: int, answer_fn: AnswerFn, ) -> tuple[PipelineCheckpointQARecord, ...]: """For each question, call ``retrieve`` then ``answer_fn`` (not ``adapter.answer``).""" out: list[PipelineCheckpointQARecord] = [] for q in checkpoint.questions: retrieval = adapter.retrieve(q.question, top_k) result = answer_fn(q, retrieval) if isinstance(result, tuple): generated, cited = result else: generated, cited = result, [] out.append( PipelineCheckpointQARecord( sample_id=sample_id, sample_uuid=sample_uuid, checkpoint_id=checkpoint.checkpoint_id, question=q.question, gold_answer=q.gold_answer, gold_evidence_memory_ids=q.gold_evidence_memory_ids, gold_evidence_contents=q.gold_evidence_contents, question_type=q.question_type, question_type_abbrev=q.question_type_abbrev, difficulty=q.difficulty, retrieval=retrieval, generated_answer=generated, cited_memories=tuple(cited), ) ) return tuple(out)