# eval_framework / datasets / domain_a_v2.py
# LCZZZZ's picture
# Upload eval_framework source code
# 85b19cf verified
"""Domain A v2 academic bundle: dialogue normalization + staged QA / gold state."""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterator, Mapping
from eval_framework.datasets.schemas import NormalizedTurn, normalize_turn
from eval_framework.pipeline.gold_state import (
SessionGoldState,
build_session_gold_states,
)
@dataclass(frozen=True)
class Stage4Record:
    """One row of ``stage4_memory_points.jsonl`` (memory points grouped by session)."""

    uuid: str
    sample_id: str
    # (session_id, memory_points) pairs in row order; each memory_points tuple
    # holds the raw entries of that session's "memory_points" list.
    memory_sessions: tuple[tuple[str, tuple[Mapping[str, Any], ...]], ...]
@dataclass(frozen=True)
class QARecord:
    """One row of ``stage4b_qa_checkpoints.jsonl`` before normalization."""

    uuid: str
    sample_id: str
    # The row's "checkpoints" list, stored as-is.
    raw_checkpoints: tuple[Mapping[str, Any], ...]
@dataclass(frozen=True)
class NormalizedCheckpointQuestion:
    """A checkpoint question with its gold answer and resolved gold evidence."""

    question: str
    # Taken from the raw question's "answer" key.
    gold_answer: str
    # memory_id of each evidence item, in evidence order.
    gold_evidence_memory_ids: tuple[str, ...]
    # Non-empty contents found for those ids; can be shorter than
    # gold_evidence_memory_ids when an id has no known content.
    gold_evidence_contents: tuple[str, ...]
    question_type: str
    question_type_abbrev: str
    difficulty: str
@dataclass(frozen=True)
class NormalizedCheckpoint:
    """A QA checkpoint with the sessions it covers and its normalized questions."""

    checkpoint_id: str
    # Session ids from the raw checkpoint's "covered_sessions", stringified.
    covered_sessions: tuple[str, ...]
    questions: tuple[NormalizedCheckpointQuestion, ...]
@dataclass(frozen=True)
class DomainAV2Session:
    """One dialogue session of a sample, with its turns normalized."""

    # Taken from the raw session's "_v2_session_id" (e.g. "S00").
    session_id: str
    turns: tuple[NormalizedTurn, ...]
@dataclass(frozen=True)
class DomainAV2AcademicSample:
    """A joined sample: dialogue sessions, stage4 memory, QA checkpoints, gold states."""

    uuid: str
    sample_id: str
    # Sessions in the order they appear in the sample's "sessions" list.
    sessions: tuple[DomainAV2Session, ...]
    stage4: Stage4Record
    qa_record: QARecord
    # qa_record.raw_checkpoints after normalization.
    normalized_checkpoints: tuple[NormalizedCheckpoint, ...]
    # Output of build_session_gold_states for this sample.
    session_gold_states: tuple[SessionGoldState, ...]
@dataclass(frozen=True)
class DomainAV2AcademicBundle:
    """Container for every loaded Domain A v2 academic sample, in load order."""

    samples: tuple[DomainAV2AcademicSample, ...]
def _read_jsonl(path: Path) -> Iterator[dict[str, Any]]:
with path.open(encoding="utf-8") as fh:
for line in fh:
line = line.strip()
if not line:
continue
yield json.loads(line)
def _stage4_from_obj(obj: Mapping[str, Any]) -> Stage4Record:
    """Build a :class:`Stage4Record` from one row of ``stage4_memory_points.jsonl``.

    Each mapping in ``obj["memory_sessions"]`` contributes one
    ``(session_id, memory_points)`` pair, in order:

    * a missing ``session_id`` becomes ``""``;
    * a missing or non-list ``memory_points`` becomes an empty tuple;
    * entries that are not mappings are skipped.

    Raises:
        KeyError: if ``obj`` lacks ``"uuid"`` or ``"sample_id"``.
    """
    blocks: list[tuple[str, tuple[Mapping[str, Any], ...]]] = []
    for ms in obj.get("memory_sessions") or []:
        # Malformed entries (or the string keys produced by iterating a
        # dict-valued "memory_sessions") would crash on ``.get``; skip them,
        # like the other loaders in this module skip non-object items.
        if not isinstance(ms, Mapping):
            continue
        sid = str(ms.get("session_id", ""))
        pts = ms.get("memory_points") or []
        if not isinstance(pts, list):
            pts = []
        blocks.append((sid, tuple(pts)))
    return Stage4Record(
        uuid=str(obj["uuid"]),
        sample_id=str(obj["sample_id"]),
        memory_sessions=tuple(blocks),
    )
def _qa_from_obj(obj: Mapping[str, Any]) -> QARecord:
    """Build a :class:`QARecord` from one row of ``stage4b_qa_checkpoints.jsonl``.

    ``obj["checkpoints"]`` is kept verbatim as a tuple. A missing, empty or
    non-list value yields an empty tuple.

    Raises:
        KeyError: if ``obj`` lacks ``"uuid"`` or ``"sample_id"``.
    """
    checkpoints = obj.get("checkpoints") or []
    raw = tuple(checkpoints) if isinstance(checkpoints, list) else ()
    return QARecord(
        uuid=str(obj["uuid"]),
        sample_id=str(obj["sample_id"]),
        raw_checkpoints=raw,
    )
def _normalize_checkpoint_question(
    raw: Mapping[str, Any],
    memory_content_map: Mapping[str, str],
) -> NormalizedCheckpointQuestion:
    """Convert one raw checkpoint question into a :class:`NormalizedCheckpointQuestion`.

    Evidence ids come from every ``dict`` item of ``raw["evidence"]`` that
    carries a ``"memory_id"`` key. They are stringified and kept in order.
    Each id's content is looked up in *memory_content_map*. Unknown ids, or
    ids mapped to an empty string, contribute no content, so
    ``gold_evidence_contents`` may be shorter than ``gold_evidence_memory_ids``.
    The remaining scalar fields are stringified and default to ``""``.
    """
    evidence = raw.get("evidence") or []
    if not isinstance(evidence, list):
        evidence = []
    mem_ids = tuple(
        str(item["memory_id"])
        for item in evidence
        if isinstance(item, dict) and "memory_id" in item
    )
    resolved = (memory_content_map.get(mid, "") for mid in mem_ids)
    mem_contents = tuple(content for content in resolved if content)

    def _as_text(key: str) -> str:
        return str(raw.get(key, ""))

    return NormalizedCheckpointQuestion(
        question=_as_text("question"),
        gold_answer=_as_text("answer"),
        gold_evidence_memory_ids=mem_ids,
        gold_evidence_contents=mem_contents,
        question_type=_as_text("question_type"),
        question_type_abbrev=_as_text("question_type_abbrev"),
        difficulty=_as_text("difficulty"),
    )
def _normalize_checkpoints(
    raw_checkpoints: tuple[Mapping[str, Any], ...],
    memory_content_map: Mapping[str, str],
) -> tuple[NormalizedCheckpoint, ...]:
    """Normalize the raw checkpoints of one QA record.

    Args:
        raw_checkpoints: checkpoint objects as stored on
            ``QARecord.raw_checkpoints``. Items that are not mappings are
            skipped.
        memory_content_map: memory_id -> memory_content lookup used to resolve
            each question's evidence.

    Returns:
        One :class:`NormalizedCheckpoint` per mapping in *raw_checkpoints*, in
        input order. Non-list ``questions`` / ``covered_sessions`` values are
        treated as empty, and questions that are not mappings are dropped.
    """
    out: list[NormalizedCheckpoint] = []
    for cp in raw_checkpoints:
        # raw_checkpoints is copied verbatim from JSON and may contain
        # non-object entries; skip them instead of crashing on ``.get``
        # (mirrors the per-question Mapping filter below).
        if not isinstance(cp, Mapping):
            continue
        qs = cp.get("questions") or []
        if not isinstance(qs, list):
            qs = []
        covered = cp.get("covered_sessions") or []
        if not isinstance(covered, list):
            covered = []
        out.append(
            NormalizedCheckpoint(
                checkpoint_id=str(cp.get("checkpoint_id", "")),
                covered_sessions=tuple(str(x) for x in covered),
                questions=tuple(
                    _normalize_checkpoint_question(q, memory_content_map)
                    for q in qs
                    if isinstance(q, Mapping)
                ),
            )
        )
    return tuple(out)
def _dialogue_turns(sample_id: str, session_id: str, dialogue: list[Any]) -> tuple[NormalizedTurn, ...]:
    """Convert a session's raw dialogue entries into normalized turns.

    The turn text for every dict entry is built as follows:

    * it starts with the entry's ``"content"`` (``""`` when missing or null);
    * when any attachment has a non-blank caption, the captions are joined
      with newlines and appended after a blank line;
    * when the content is empty, the text is the caption block alone.

    The attachments themselves are not forwarded: the raw turn is built with
    ``"attachments": []``. String timestamps pass through and other non-null
    values are stringified; a missing or null timestamp stays ``None``. The
    role defaults to ``"user"``.

    Non-dict entries are skipped, but ``turn_index`` is still the entry's
    position in *dialogue*, so indices may have gaps.

    Args:
        sample_id: id of the sample the session belongs to.
        session_id: id of the session being converted.
        dialogue: raw dialogue list from the session object.

    Returns:
        The turns produced by ``normalize_turn``, in dialogue order.
    """
    turns: list[NormalizedTurn] = []
    for turn_index, entry in enumerate(dialogue):
        if not isinstance(entry, dict):
            continue
        content = entry.get("content")
        # A null content would otherwise become the literal text "None".
        text = "" if content is None else str(content)

        captions: list[str] = []
        attachments_raw = entry.get("attachments") or []
        if isinstance(attachments_raw, list):
            for att in attachments_raw:
                if not isinstance(att, dict):
                    continue
                cap = att.get("caption")
                if cap is None:
                    # Missing/null caption: skip rather than inject "" or "None".
                    continue
                cap_text = cap if isinstance(cap, str) else str(cap)
                # Blank captions would only add stray separator newlines.
                if cap_text.strip():
                    captions.append(cap_text)
        if captions:
            caption_block = "\n".join(captions)
            # Only add the blank-line separator when there is content before it.
            text = f"{text}\n\n{caption_block}" if text else caption_block

        ts = entry.get("timestamp")
        timestamp = ts if ts is None or isinstance(ts, str) else str(ts)
        raw_turn = {
            "sample_id": sample_id,
            "session_id": session_id,
            "turn_index": turn_index,
            "role": str(entry.get("role", "user")),
            "text": text,
            "attachments": [],
            "timestamp": timestamp,
        }
        turns.append(normalize_turn(raw_turn))
    return tuple(turns)
def _memory_content_map(*groups: tuple[Mapping[str, Any], ...]) -> dict[str, str]:
    """Map ``memory_id -> memory_content`` over *groups*; later groups win on duplicate ids.

    Points that are not mappings, or whose ``memory_id`` / ``memory_content``
    is missing or ``None``, are ignored. Ids and contents are stringified.
    """
    content_by_id: dict[str, str] = {}
    for points in groups:
        for mp_raw in points:
            if not isinstance(mp_raw, Mapping):
                continue
            mid = mp_raw.get("memory_id")
            mc = mp_raw.get("memory_content")
            if mid is not None and mc is not None:
                content_by_id[str(mid)] = str(mc)
    return content_by_id


def _build_sessions(
    sample_id: str,
    sessions_raw: Any,
) -> tuple[list[DomainAV2Session], list[str], tuple[Mapping[str, Any], ...]]:
    """Normalize a sample's raw ``sessions`` value.

    Returns ``(sessions, ordered_session_ids, s00_memory_points)``. A non-list
    value yields no sessions. Entries that are not dicts, or whose
    ``_v2_session_id`` is empty, are skipped. ``s00_memory_points`` is the
    ``memory_points`` list of the ``"S00"`` session as a tuple, or ``()``
    when absent or not a list.
    """
    if not isinstance(sessions_raw, list):
        sessions_raw = []
    sessions: list[DomainAV2Session] = []
    ordered_ids: list[str] = []
    s00_points: tuple[Mapping[str, Any], ...] = ()
    for sess in sessions_raw:
        if not isinstance(sess, dict):
            continue
        sid = str(sess.get("_v2_session_id", ""))
        if not sid:
            continue
        ordered_ids.append(sid)
        dialogue = sess.get("dialogue") or []
        if not isinstance(dialogue, list):
            dialogue = []
        sessions.append(
            DomainAV2Session(
                session_id=sid,
                turns=_dialogue_turns(sample_id, sid, dialogue),
            )
        )
        if sid == "S00":
            mps = sess.get("memory_points") or []
            if isinstance(mps, list):
                s00_points = tuple(mps)
    return sessions, ordered_ids, s00_points


def load_domain_a_v2_academic(data_dir: Path | str) -> DomainAV2AcademicBundle:
    """Load the Domain A v2 academic bundle from *data_dir*.

    Three files are read from the directory:

    * ``domain_a_v2.json``: a JSON list of samples with dialogue sessions;
    * ``stage4_memory_points.jsonl``: per-sample memory points by session;
    * ``stage4b_qa_checkpoints.jsonl``: per-sample QA checkpoints.

    JSONL rows are joined to samples by ``sample_id``. If a ``sample_id``
    repeats within a JSONL file, the last row wins. Non-dict entries of the
    main list are skipped.

    Args:
        data_dir: directory containing the three files. A ``str`` is also
            accepted and converted to a :class:`~pathlib.Path`.

    Returns:
        A bundle with one sample per dict entry of ``domain_a_v2.json``, in
        file order.

    Raises:
        ValueError: if ``domain_a_v2.json`` does not contain a JSON list.
        KeyError: if a sample lacks ``sample_id`` / ``uuid`` or has no matching
            stage4 or QA row.
    """
    data_dir = Path(data_dir).resolve()
    main_path = data_dir / "domain_a_v2.json"
    stage4_path = data_dir / "stage4_memory_points.jsonl"
    qa_path = data_dir / "stage4b_qa_checkpoints.jsonl"
    raw_samples = json.loads(main_path.read_text(encoding="utf-8"))
    if not isinstance(raw_samples, list):
        raise ValueError("domain_a_v2.json must be a list")

    # Dict comprehensions keep "last row wins" semantics for duplicate sample_ids.
    stage4_by_id: dict[str, Stage4Record] = {
        rec.sample_id: rec for rec in map(_stage4_from_obj, _read_jsonl(stage4_path))
    }
    qa_by_id: dict[str, QARecord] = {
        rec.sample_id: rec for rec in map(_qa_from_obj, _read_jsonl(qa_path))
    }

    built: list[DomainAV2AcademicSample] = []
    for item in raw_samples:
        if not isinstance(item, dict):
            continue
        sample_id = str(item["sample_id"])
        uuid = str(item["uuid"])
        stage4 = stage4_by_id.get(sample_id)
        qa = qa_by_id.get(sample_id)
        if stage4 is None or qa is None:
            raise KeyError(f"missing stage4 or QA row for sample_id={sample_id}")

        sessions, ordered_ids, s00_points = _build_sessions(sample_id, item.get("sessions"))
        gold_states = build_session_gold_states(
            ordered_ids,
            s00_memory_points=s00_points,
            stage4_by_session_id=dict(stage4.memory_sessions),
        )
        # Stage4 points are applied after the S00 points, so they win on
        # duplicate memory ids.
        memory_content_map = _memory_content_map(
            s00_points, *(pts for _sid, pts in stage4.memory_sessions)
        )
        built.append(
            DomainAV2AcademicSample(
                uuid=uuid,
                sample_id=sample_id,
                sessions=tuple(sessions),
                stage4=stage4,
                qa_record=qa,
                normalized_checkpoints=_normalize_checkpoints(
                    qa.raw_checkpoints, memory_content_map
                ),
                session_gold_states=gold_states,
            )
        )
    return DomainAV2AcademicBundle(samples=tuple(built))