"""Session-by-session ingest, memory export, and checkpoint QA orchestration.""" from __future__ import annotations from collections.abc import Callable from eval_framework.datasets.domain_a_v2 import ( DomainAV2AcademicSample, NormalizedCheckpointQuestion, ) from eval_framework.memory_adapters.base import MemoryAdapter from eval_framework.pipeline.qa_runner import run_checkpoint_qa_records from eval_framework.pipeline.records import PipelineCheckpointQARecord, PipelineSessionRecord from eval_framework.datasets.schemas import RetrievalRecord def ensure_adapter_available(adapter: MemoryAdapter) -> None: caps = adapter.get_capabilities() if caps.get("available") is False: backend = caps.get("backend", type(adapter).__name__) detail = caps.get("integration_error") or caps.get( "integration_status", "available=False" ) raise RuntimeError( f"Memory adapter {backend!r} is not available for pipeline runs: {detail}" ) def run_domain_a_v2_sample( adapter: MemoryAdapter, sample: DomainAV2AcademicSample, *, top_k: int = 5, answer_fn: Callable | None = None, ) -> tuple[tuple[PipelineSessionRecord, ...], tuple[PipelineCheckpointQARecord, ...]]: """Run all sessions in order, emit one session record per session, then checkpoint QA when due.""" ensure_adapter_available(adapter) if sample.normalized_checkpoints and answer_fn is None: raise ValueError( "answer_fn is required when the sample defines normalized checkpoints" ) adapter.reset() session_out: list[PipelineSessionRecord] = [] qa_out: list[PipelineCheckpointQARecord] = [] completed_sessions: set[str] = set() session_order = { session.session_id: index for index, session in enumerate(sample.sessions) } if len(sample.sessions) != len(sample.session_gold_states): raise ValueError( "sample.sessions and sample.session_gold_states length mismatch" ) for sess, gold in zip(sample.sessions, sample.session_gold_states): if sess.session_id != gold.session_id: raise ValueError( f"session / gold_state id mismatch: {sess.session_id!r} vs {gold.session_id!r}" ) for turn in sess.turns: adapter.ingest_turn(turn) adapter.end_session(sess.session_id) snapshot = tuple(adapter.snapshot_memories()) delta = tuple(adapter.export_memory_delta(sess.session_id)) session_out.append( PipelineSessionRecord( sample_id=sample.sample_id, sample_uuid=sample.uuid, session_id=sess.session_id, memory_snapshot=snapshot, memory_delta=delta, gold_state=gold, ) ) completed_sessions.add(sess.session_id) for cp in sample.normalized_checkpoints: covered = cp.covered_sessions if not covered: continue missing = [sid for sid in covered if sid not in session_order] if missing: raise ValueError( f"checkpoint {cp.checkpoint_id!r} references unknown sessions: {missing}" ) if not set(covered).issubset(completed_sessions): continue trigger_session_id = max(covered, key=session_order.__getitem__) if sess.session_id != trigger_session_id: continue qa_out.extend( run_checkpoint_qa_records( adapter, sample_id=sample.sample_id, sample_uuid=sample.uuid, checkpoint=cp, top_k=top_k, answer_fn=answer_fn, ) ) return tuple(session_out), tuple(qa_out)