eval_framework / pipeline /runner.py
LCZZZZ's picture
Upload eval_framework source code
85b19cf verified
"""Session-by-session ingest, memory export, and checkpoint QA orchestration."""
from __future__ import annotations
from collections.abc import Callable
from eval_framework.datasets.domain_a_v2 import (
DomainAV2AcademicSample,
NormalizedCheckpointQuestion,
)
from eval_framework.memory_adapters.base import MemoryAdapter
from eval_framework.pipeline.qa_runner import run_checkpoint_qa_records
from eval_framework.pipeline.records import PipelineCheckpointQARecord, PipelineSessionRecord
from eval_framework.datasets.schemas import RetrievalRecord
def ensure_adapter_available(adapter: MemoryAdapter) -> None:
    """Raise ``RuntimeError`` if the adapter reports itself as unavailable.

    Only an explicit ``False`` under the ``"available"`` capability key counts
    as unavailable; a missing key or any other value returns normally.

    Args:
        adapter: Object whose ``get_capabilities()`` returns a mapping.

    Raises:
        RuntimeError: When ``capabilities["available"] is False``. The message
            names the ``"backend"`` capability (falling back to the adapter's
            class name) and the first truthy value among
            ``"integration_error"`` and ``"integration_status"``, defaulting
            to ``"available=False"``.
    """
    capabilities = adapter.get_capabilities()
    if capabilities.get("available") is not False:
        return

    backend_name = capabilities.get("backend", type(adapter).__name__)
    reason = capabilities.get("integration_error")
    if not reason:
        # No usable error text: fall back to the status field, then a fixed marker.
        reason = capabilities.get("integration_status", "available=False")
    raise RuntimeError(
        f"Memory adapter {backend_name!r} is not available for pipeline runs: {reason}"
    )
def _index_sessions(sample: DomainAV2AcademicSample) -> dict[str, int]:
    """Check session/gold-state pairing and return ``session_id -> position``."""
    if len(sample.sessions) != len(sample.session_gold_states):
        raise ValueError(
            "sample.sessions and sample.session_gold_states length mismatch"
        )
    for sess, gold in zip(sample.sessions, sample.session_gold_states):
        if sess.session_id != gold.session_id:
            raise ValueError(
                f"session / gold_state id mismatch: {sess.session_id!r} vs {gold.session_id!r}"
            )
    return {
        session.session_id: index for index, session in enumerate(sample.sessions)
    }


def _checkpoints_by_trigger(
    sample: DomainAV2AcademicSample,
    session_order: dict[str, int],
) -> dict[str, list]:
    """Group checkpoints under the id of the latest session each one covers."""
    due: dict[str, list] = {}
    for cp in sample.normalized_checkpoints:
        covered = cp.covered_sessions
        if not covered:
            # A checkpoint that covers no sessions has no trigger and never runs.
            continue
        missing = [sid for sid in covered if sid not in session_order]
        if missing:
            raise ValueError(
                f"checkpoint {cp.checkpoint_id!r} references unknown sessions: {missing}"
            )
        trigger_session_id = max(covered, key=session_order.__getitem__)
        due.setdefault(trigger_session_id, []).append(cp)
    return due


def run_domain_a_v2_sample(
    adapter: MemoryAdapter,
    sample: DomainAV2AcademicSample,
    *,
    top_k: int = 5,
    answer_fn: Callable | None = None,
) -> tuple[tuple[PipelineSessionRecord, ...], tuple[PipelineCheckpointQARecord, ...]]:
    """Run all sessions in order, emit one session record per session, then checkpoint QA when due.

    The sample is validated in full before ``adapter.reset()`` is called, so a
    malformed sample raises without the adapter having been reset or fed any
    turns.

    Args:
        adapter: Memory adapter that receives every turn and is queried for a
            snapshot and a delta after each session.
        sample: Sample providing ``sessions``, ``session_gold_states`` (paired
            positionally with ``sessions``) and ``normalized_checkpoints``.
        top_k: Forwarded unchanged to ``run_checkpoint_qa_records``.
        answer_fn: Forwarded unchanged to ``run_checkpoint_qa_records``;
            required when the sample has any normalized checkpoints.

    Returns:
        A pair ``(session_records, qa_records)``. ``session_records`` holds one
        ``PipelineSessionRecord`` per session, in session order.
        ``qa_records`` holds the QA records of each checkpoint, emitted right
        after the latest session that checkpoint covers.

    Raises:
        RuntimeError: If the adapter reports itself as unavailable.
        ValueError: If checkpoints exist but ``answer_fn`` is ``None``; if
            ``sessions`` and ``session_gold_states`` differ in length or in
            ``session_id`` at any position; or if a checkpoint covers a
            session id that is not in the sample (this includes a sample
            with no sessions at all).
    """
    ensure_adapter_available(adapter)
    if sample.normalized_checkpoints and answer_fn is None:
        raise ValueError(
            "answer_fn is required when the sample defines normalized checkpoints"
        )

    # Validate everything before touching the adapter, and work out once which
    # session triggers each checkpoint rather than once per session.
    session_order = _index_sessions(sample)
    due_by_trigger = _checkpoints_by_trigger(sample, session_order)

    adapter.reset()
    session_out: list[PipelineSessionRecord] = []
    qa_out: list[PipelineCheckpointQARecord] = []
    completed_sessions: set[str] = set()

    for sess, gold in zip(sample.sessions, sample.session_gold_states):
        for turn in sess.turns:
            adapter.ingest_turn(turn)
        adapter.end_session(sess.session_id)

        # The snapshot is taken before the delta export, after the session is closed.
        snapshot = tuple(adapter.snapshot_memories())
        delta = tuple(adapter.export_memory_delta(sess.session_id))
        session_out.append(
            PipelineSessionRecord(
                sample_id=sample.sample_id,
                sample_uuid=sample.uuid,
                session_id=sess.session_id,
                memory_snapshot=snapshot,
                memory_delta=delta,
                gold_state=gold,
            )
        )
        completed_sessions.add(sess.session_id)

        for cp in due_by_trigger.get(sess.session_id, ()):
            # Run QA only once every covered session has actually been ingested.
            if not completed_sessions.issuperset(cp.covered_sessions):
                continue
            qa_out.extend(
                run_checkpoint_qa_records(
                    adapter,
                    sample_id=sample.sample_id,
                    sample_uuid=sample.uuid,
                    checkpoint=cp,
                    top_k=top_k,
                    answer_fn=answer_fn,
                )
            )

    return tuple(session_out), tuple(qa_out)