| """Session-by-session ingest, memory export, and checkpoint QA orchestration.""" |
|
|
| from __future__ import annotations |
|
|
| from collections.abc import Callable |
|
|
| from eval_framework.datasets.domain_a_v2 import ( |
| DomainAV2AcademicSample, |
| NormalizedCheckpointQuestion, |
| ) |
| from eval_framework.memory_adapters.base import MemoryAdapter |
| from eval_framework.pipeline.qa_runner import run_checkpoint_qa_records |
| from eval_framework.pipeline.records import PipelineCheckpointQARecord, PipelineSessionRecord |
| from eval_framework.datasets.schemas import RetrievalRecord |
|
|
|
|
def ensure_adapter_available(adapter: MemoryAdapter) -> None:
    """Raise ``RuntimeError`` unless *adapter* reports itself as available.

    An adapter is considered unavailable only when its capabilities dict
    explicitly carries ``available: False``; a missing key counts as available.
    """
    capabilities = adapter.get_capabilities()
    if capabilities.get("available") is not False:
        return
    # Prefer the adapter's self-reported backend name; fall back to its class.
    backend_name = capabilities.get("backend", type(adapter).__name__)
    # integration_error wins when truthy; otherwise use integration_status
    # (or a generic marker when neither is present).
    reason = capabilities.get("integration_error")
    if not reason:
        reason = capabilities.get("integration_status", "available=False")
    raise RuntimeError(
        f"Memory adapter {backend_name!r} is not available for pipeline runs: {reason}"
    )
|
|
|
|
def run_domain_a_v2_sample(
    adapter: MemoryAdapter,
    sample: DomainAV2AcademicSample,
    *,
    top_k: int = 5,
    answer_fn: Callable | None = None,
) -> tuple[tuple[PipelineSessionRecord, ...], tuple[PipelineCheckpointQARecord, ...]]:
    """Run all sessions in order, emit one session record per session, then checkpoint QA when due.

    The adapter is reset, each session's turns are ingested in order, and after
    every session a ``PipelineSessionRecord`` capturing the full memory snapshot
    and the session's memory delta is emitted.  A normalized checkpoint fires
    exactly once: immediately after its last covered session completes.

    Args:
        adapter: Memory backend; must report itself available.
        sample: Sessions, per-session gold states, and normalized checkpoints.
        top_k: Retrieval depth passed through to checkpoint QA.
        answer_fn: Answer generator; required when the sample defines
            normalized checkpoints.

    Returns:
        ``(session_records, checkpoint_qa_records)`` as tuples.

    Raises:
        RuntimeError: If the adapter is unavailable.
        ValueError: On missing ``answer_fn``, session/gold-state length or id
            mismatch, duplicate session ids, or a checkpoint referencing an
            unknown session.
    """
    ensure_adapter_available(adapter)
    if sample.normalized_checkpoints and answer_fn is None:
        raise ValueError(
            "answer_fn is required when the sample defines normalized checkpoints"
        )
    if len(sample.sessions) != len(sample.session_gold_states):
        raise ValueError(
            "sample.sessions and sample.session_gold_states length mismatch"
        )

    session_order = {
        session.session_id: index for index, session in enumerate(sample.sessions)
    }
    # Duplicate ids would silently collapse in session_order (last index wins)
    # and make checkpoint trigger resolution ambiguous — reject them up front.
    if len(session_order) != len(sample.sessions):
        raise ValueError("sample.sessions contains duplicate session_ids")
    # Validate checkpoint references before touching adapter state, so a bad
    # sample fails fast instead of mid-run after reset/ingest side effects.
    for cp in sample.normalized_checkpoints:
        missing = [sid for sid in cp.covered_sessions if sid not in session_order]
        if missing:
            raise ValueError(
                f"checkpoint {cp.checkpoint_id!r} references unknown sessions: {missing}"
            )

    adapter.reset()
    session_out: list[PipelineSessionRecord] = []
    qa_out: list[PipelineCheckpointQARecord] = []
    completed_sessions: set[str] = set()

    for sess, gold in zip(sample.sessions, sample.session_gold_states):
        if sess.session_id != gold.session_id:
            raise ValueError(
                f"session / gold_state id mismatch: {sess.session_id!r} vs {gold.session_id!r}"
            )
        for turn in sess.turns:
            adapter.ingest_turn(turn)
        adapter.end_session(sess.session_id)

        session_out.append(
            PipelineSessionRecord(
                sample_id=sample.sample_id,
                sample_uuid=sample.uuid,
                session_id=sess.session_id,
                memory_snapshot=tuple(adapter.snapshot_memories()),
                memory_delta=tuple(adapter.export_memory_delta(sess.session_id)),
                gold_state=gold,
            )
        )
        completed_sessions.add(sess.session_id)

        for cp in sample.normalized_checkpoints:
            covered = cp.covered_sessions
            if not covered:
                continue
            if not set(covered).issubset(completed_sessions):
                continue
            # A checkpoint fires only at its latest covered session, so it
            # runs exactly once even though it stays "covered" afterwards.
            trigger_session_id = max(covered, key=session_order.__getitem__)
            if sess.session_id != trigger_session_id:
                continue
            qa_out.extend(
                run_checkpoint_qa_records(
                    adapter,
                    sample_id=sample.sample_id,
                    sample_uuid=sample.uuid,
                    checkpoint=cp,
                    top_k=top_k,
                    answer_fn=answer_fn,
                )
            )

    return tuple(session_out), tuple(qa_out)
|
|