"""Domain A v2 academic bundle: dialogue normalization plus staged QA and gold state.

Frozen record types for one loaded dataset. Each sample joins three sources:
the raw sessions from ``domain_a_v2.json``, its ``stage4_memory_points.jsonl``
row, and its ``stage4b_qa_checkpoints.jsonl`` row. The loader functions later
in this module build these records.
"""

from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterator, Mapping

from eval_framework.datasets.schemas import NormalizedTurn, normalize_turn
from eval_framework.pipeline.gold_state import (
    SessionGoldState,
    build_session_gold_states,
)


@dataclass(frozen=True)
class Stage4Record:
    """One parsed row of ``stage4_memory_points.jsonl``.

    NOTE(review): memory points are kept as the raw objects from ``json.loads``
    (normally dicts). Calling ``hash()`` on an instance therefore raises
    ``TypeError`` despite ``frozen=True``. The nested mappings also stay mutable.
    """

    uuid: str
    sample_id: str  # join key against the main sample list
    # (session_id, raw memory points) pairs, in the row's original order.
    memory_sessions: tuple[tuple[str, tuple[Mapping[str, Any], ...]], ...]


@dataclass(frozen=True)
class QARecord:
    """One parsed row of ``stage4b_qa_checkpoints.jsonl``, with checkpoints left raw."""

    uuid: str
    sample_id: str  # join key against the main sample list
    # The row's "checkpoints" list; empty when absent or not a list.
    raw_checkpoints: tuple[Mapping[str, Any], ...]


@dataclass(frozen=True)
class NormalizedCheckpointQuestion:
    """A QA question with string-coerced fields and resolved evidence."""

    question: str  # raw "question", "" when absent
    gold_answer: str  # raw "answer", "" when absent
    # memory_id of every evidence item that carries one, in evidence order.
    gold_evidence_memory_ids: tuple[str, ...]
    # Contents only for ids that resolved to non-empty memory content. This
    # may be shorter than gold_evidence_memory_ids and is not index-aligned with it.
    gold_evidence_contents: tuple[str, ...]
    question_type: str
    question_type_abbrev: str
    difficulty: str


@dataclass(frozen=True)
class NormalizedCheckpoint:
    """A QA checkpoint: the sessions it covers and its normalized questions."""

    checkpoint_id: str
    covered_sessions: tuple[str, ...]  # session ids, coerced to str
    questions: tuple[NormalizedCheckpointQuestion, ...]


@dataclass(frozen=True)
class DomainAV2Session:
    """One dialogue session; ``session_id`` comes from the raw ``_v2_session_id``."""

    session_id: str
    turns: tuple[NormalizedTurn, ...]


@dataclass(frozen=True)
class DomainAV2AcademicSample:
    """A fully joined sample: dialogue, stage4 memory rows, QA and gold states."""

    uuid: str
    sample_id: str
    sessions: tuple[DomainAV2Session, ...]  # in the main file's session order
    stage4: Stage4Record
    qa_record: QARecord
    normalized_checkpoints: tuple[NormalizedCheckpoint, ...]
    session_gold_states: tuple[SessionGoldState, ...]


@dataclass(frozen=True)
class DomainAV2AcademicBundle:
    """All loaded samples, in the order they appear in ``domain_a_v2.json``."""

    samples: tuple[DomainAV2AcademicSample, ...]
def _read_jsonl(path: Path) -> Iterator[dict[str, Any]]:
    """Yield one decoded JSON value per non-blank line of a JSON Lines file.

    Args:
        path: File to read, decoded as UTF-8.

    Yields:
        The decoded value of each non-blank line (expected to be an object).

    Raises:
        json.JSONDecodeError: If a line is not valid JSON. The message is
            prefixed with ``<path>:<line number>`` to locate the bad row.
    """
    with path.open(encoding="utf-8") as fh:
        for lineno, raw_line in enumerate(fh, start=1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as exc:
                # Keep the exception type so existing handlers still match, but
                # add the file position: json itself only sees the single line.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {exc.msg}", exc.doc, exc.pos
                ) from exc
            yield obj


def _stage4_from_obj(obj: Mapping[str, Any]) -> Stage4Record:
    """Build a :class:`Stage4Record` from one ``stage4_memory_points.jsonl`` row.

    A non-list ``memory_sessions`` is treated as empty. Session entries that
    are not mappings are skipped, and a non-list ``memory_points`` becomes an
    empty tuple. Memory points themselves are kept raw.

    Raises:
        KeyError: If ``uuid`` or ``sample_id`` is missing.
    """
    sessions = obj.get("memory_sessions") or []
    if not isinstance(sessions, list):
        sessions = []
    blocks: list[tuple[str, tuple[Mapping[str, Any], ...]]] = []
    for ms in sessions:
        if not isinstance(ms, Mapping):
            continue
        sid = str(ms.get("session_id", ""))
        pts = ms.get("memory_points") or []
        if not isinstance(pts, list):
            pts = []
        blocks.append((sid, tuple(pts)))
    return Stage4Record(
        uuid=str(obj["uuid"]),
        sample_id=str(obj["sample_id"]),
        memory_sessions=tuple(blocks),
    )


def _qa_from_obj(obj: Mapping[str, Any]) -> QARecord:
    """Build a :class:`QARecord` from one ``stage4b_qa_checkpoints.jsonl`` row.

    A missing or non-list ``checkpoints`` value becomes an empty tuple.

    Raises:
        KeyError: If ``uuid`` or ``sample_id`` is missing.
    """
    cps = obj.get("checkpoints") or []
    if not isinstance(cps, list):
        cps = []
    return QARecord(
        uuid=str(obj["uuid"]),
        sample_id=str(obj["sample_id"]),
        raw_checkpoints=tuple(cps),
    )


def _normalize_checkpoint_question(
    raw: Mapping[str, Any],
    memory_content_map: Mapping[str, str],
) -> NormalizedCheckpointQuestion:
    """Normalize one raw QA question and resolve its evidence contents.

    Args:
        raw: Raw question mapping (``question``, ``answer``, ``evidence``, ...).
        memory_content_map: ``memory_id`` -> memory content for this sample.

    Returns:
        The normalized question. Every evidence item with a ``memory_id``
        contributes its id. Only ids that resolve to non-empty content
        contribute to ``gold_evidence_contents``, so the two tuples may differ
        in length.
    """
    evidence = raw.get("evidence") or []
    mem_ids: list[str] = []
    mem_contents: list[str] = []
    if isinstance(evidence, list):
        for item in evidence:
            if not isinstance(item, Mapping) or "memory_id" not in item:
                continue
            mid = str(item["memory_id"])
            mem_ids.append(mid)
            content = memory_content_map.get(mid, "")
            if content:
                mem_contents.append(content)
    return NormalizedCheckpointQuestion(
        question=str(raw.get("question", "")),
        gold_answer=str(raw.get("answer", "")),
        gold_evidence_memory_ids=tuple(mem_ids),
        gold_evidence_contents=tuple(mem_contents),
        question_type=str(raw.get("question_type", "")),
        question_type_abbrev=str(raw.get("question_type_abbrev", "")),
        difficulty=str(raw.get("difficulty", "")),
    )


def _normalize_checkpoints(
    raw_checkpoints: tuple[Mapping[str, Any], ...],
    memory_content_map: Mapping[str, str],
) -> tuple[NormalizedCheckpoint, ...]:
    """Normalize raw QA checkpoints, skipping entries that are not mappings.

    Non-list ``questions``/``covered_sessions`` values are treated as empty,
    and non-mapping questions are skipped.
    """
    out: list[NormalizedCheckpoint] = []
    for cp in raw_checkpoints:
        if not isinstance(cp, Mapping):
            continue
        qs = cp.get("questions") or []
        if not isinstance(qs, list):
            qs = []
        covered = cp.get("covered_sessions") or []
        if not isinstance(covered, list):
            covered = []
        out.append(
            NormalizedCheckpoint(
                checkpoint_id=str(cp.get("checkpoint_id", "")),
                covered_sessions=tuple(str(x) for x in covered),
                questions=tuple(
                    _normalize_checkpoint_question(q, memory_content_map)
                    for q in qs
                    if isinstance(q, Mapping)
                ),
            )
        )
    return tuple(out)


def _attachment_captions(attachments_raw: Any) -> list[str]:
    """Return the non-blank captions of mapping attachments, stringified, in order."""
    captions: list[str] = []
    if not isinstance(attachments_raw, list):
        return captions
    for att in attachments_raw:
        if not isinstance(att, Mapping):
            continue
        cap = att.get("caption")
        if cap is None:
            # A JSON null caption must not become the literal text "None".
            continue
        cap_text = cap if isinstance(cap, str) else str(cap)
        if cap_text.strip():
            captions.append(cap_text)
    return captions


def _dialogue_turns(sample_id: str, session_id: str, dialogue: list[Any]) -> tuple[NormalizedTurn, ...]:
    """Convert raw dialogue entries into normalized turns.

    Attachment captions are folded into the turn text, separated from the
    content by a blank line. Attachments themselves are not forwarded.
    ``turn_index`` is the entry's position in the raw list, so skipped
    non-mapping entries leave gaps in the numbering.
    """
    turns: list[NormalizedTurn] = []
    for turn_index, entry in enumerate(dialogue):
        if not isinstance(entry, Mapping):
            continue
        text = str(entry.get("content", ""))
        captions = _attachment_captions(entry.get("attachments") or [])
        if captions:
            text = text + "\n\n" + "\n".join(captions)
        ts = entry.get("timestamp")
        timestamp = ts if isinstance(ts, str) else (str(ts) if ts is not None else None)
        raw_turn = {
            "sample_id": sample_id,
            "session_id": session_id,
            "turn_index": turn_index,
            "role": str(entry.get("role", "user")),
            "text": text,
            "attachments": [],
            "timestamp": timestamp,
        }
        turns.append(normalize_turn(raw_turn))
    return tuple(turns)


def _sessions_from_raw(
    sample_id: str,
    sessions_raw: list[Any],
) -> tuple[list[DomainAV2Session], list[str], tuple[Mapping[str, Any], ...]]:
    """Normalize a sample's raw sessions.

    Returns:
        ``(sessions, ordered_session_ids, s00_memory_points)``. Sessions that
        are not mappings, or that lack ``_v2_session_id``, are dropped. The
        S00 memory points come from the session whose id is ``"S00"``; the
        last such session wins, and the result is empty if there is none.
    """
    session_blocks: list[DomainAV2Session] = []
    ordered_ids: list[str] = []
    s00_points: tuple[Mapping[str, Any], ...] = ()
    for sess in sessions_raw:
        if not isinstance(sess, Mapping):
            continue
        sid = str(sess.get("_v2_session_id", ""))
        if not sid:
            continue
        ordered_ids.append(sid)
        dialogue = sess.get("dialogue") or []
        if not isinstance(dialogue, list):
            dialogue = []
        session_blocks.append(
            DomainAV2Session(
                session_id=sid,
                turns=_dialogue_turns(sample_id, sid, dialogue),
            )
        )
        if sid == "S00":
            mps = sess.get("memory_points") or []
            if isinstance(mps, list):
                s00_points = tuple(mps)
    return session_blocks, ordered_ids, s00_points


def _collect_memory_contents(
    *point_groups: tuple[Mapping[str, Any], ...],
) -> dict[str, str]:
    """Map ``memory_id`` -> ``memory_content`` across groups of raw memory points.

    Groups are applied in order, so a later group overrides an earlier one on a
    duplicate id. Non-mapping points, and points missing either field, are ignored.
    """
    content_by_id: dict[str, str] = {}
    for group in point_groups:
        for point in group:
            if not isinstance(point, Mapping):
                continue
            mid = point.get("memory_id")
            content = point.get("memory_content")
            if mid is not None and content is not None:
                content_by_id[str(mid)] = str(content)
    return content_by_id


def load_domain_a_v2_academic(data_dir: Path | str) -> DomainAV2AcademicBundle:
    """Load and join the Domain A v2 academic dataset found in ``data_dir``.

    Reads three files:

    * ``domain_a_v2.json`` is a JSON list of samples (``sample_id``, ``uuid``,
      ``sessions``).
    * ``stage4_memory_points.jsonl`` holds stage4 memory rows keyed by
      ``sample_id``.
    * ``stage4b_qa_checkpoints.jsonl`` holds QA checkpoint rows keyed by
      ``sample_id``.

    If a ``sample_id`` repeats within a JSONL file, the last row wins.
    Non-mapping entries in the main list are skipped.

    Args:
        data_dir: Dataset directory. A ``str`` is accepted and converted to ``Path``.

    Returns:
        The joined bundle, with samples in the order of the main file.

    Raises:
        ValueError: If ``domain_a_v2.json`` is not a JSON list.
        KeyError: If a sample lacks ``sample_id``/``uuid`` or has no stage4 or QA row.
    """
    data_dir = Path(data_dir).resolve()
    main_path = data_dir / "domain_a_v2.json"
    stage4_path = data_dir / "stage4_memory_points.jsonl"
    qa_path = data_dir / "stage4b_qa_checkpoints.jsonl"

    raw_samples = json.loads(main_path.read_text(encoding="utf-8"))
    if not isinstance(raw_samples, list):
        raise ValueError("domain_a_v2.json must be a list")

    stage4_by_id: dict[str, Stage4Record] = {}
    for obj in _read_jsonl(stage4_path):
        stage4_rec = _stage4_from_obj(obj)
        stage4_by_id[stage4_rec.sample_id] = stage4_rec

    qa_by_id: dict[str, QARecord] = {}
    for obj in _read_jsonl(qa_path):
        qa_rec = _qa_from_obj(obj)
        qa_by_id[qa_rec.sample_id] = qa_rec

    built: list[DomainAV2AcademicSample] = []
    for item in raw_samples:
        if not isinstance(item, dict):
            continue
        sample_id = str(item["sample_id"])
        uuid = str(item["uuid"])
        stage4 = stage4_by_id.get(sample_id)
        qa = qa_by_id.get(sample_id)
        missing = [label for label, rec in (("stage4", stage4), ("QA", qa)) if rec is None]
        if missing:
            raise KeyError(f"missing {' and '.join(missing)} row for sample_id={sample_id}")

        sessions_raw = item.get("sessions") or []
        if not isinstance(sessions_raw, list):
            sessions_raw = []
        session_blocks, ordered_ids, s00_points = _sessions_from_raw(sample_id, sessions_raw)

        gold_states = build_session_gold_states(
            ordered_ids,
            s00_memory_points=s00_points,
            stage4_by_session_id={sid: pts for sid, pts in stage4.memory_sessions},
        )
        # S00 points first, then stage4 sessions in order; stage4 wins on duplicate ids.
        memory_content_map = _collect_memory_contents(
            s00_points, *(pts for _sid, pts in stage4.memory_sessions)
        )

        built.append(
            DomainAV2AcademicSample(
                uuid=uuid,
                sample_id=sample_id,
                sessions=tuple(session_blocks),
                stage4=stage4,
                qa_record=qa,
                normalized_checkpoints=_normalize_checkpoints(
                    qa.raw_checkpoints, memory_content_map
                ),
                session_gold_states=gold_states,
            )
        )
    return DomainAV2AcademicBundle(samples=tuple(built))