"""Load question banks from JSON or JSONL files.""" from __future__ import annotations import json from pathlib import Path from typing import Any from .question import TSQuestion # Canonical domain keys used by EpisodeSampler (must match bank files or dataset field) DEFAULT_DOMAIN_ORDER = ("PSML", "freshretailnet", "MIMIC", "causal_chambers") def _parse_records(raw: Any) -> list[dict[str, Any]]: if isinstance(raw, list): return [x for x in raw if isinstance(x, dict)] if isinstance(raw, dict) and "questions" in raw: q = raw["questions"] if isinstance(q, list): return [x for x in q if isinstance(x, dict)] raise ValueError("JSON root must be a list of objects or {\"questions\": [...]}") def _record_to_question(obj: dict[str, Any]) -> TSQuestion: return TSQuestion.model_validate(obj) def load_json_file(path: Path) -> list[TSQuestion]: """Load a single .json file (array or {\"questions\": [...]}).""" raw = json.loads(path.read_text(encoding="utf-8")) records = _parse_records(raw) return [_record_to_question(r) for r in records] def load_jsonl_file(path: Path) -> list[TSQuestion]: """Load newline-delimited JSON; each line must be a full TSQuestion object.""" out: list[TSQuestion] = [] for line_no, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1): line = line.strip() if not line: continue try: obj = json.loads(line) except json.JSONDecodeError as e: raise ValueError(f"{path}:{line_no}: invalid JSON: {e}") from e if not isinstance(obj, dict): raise ValueError(f"{path}:{line_no}: expected object per line") out.append(_record_to_question(obj)) return out def load_question_banks( bank_dir: Path | str | None, *, domain_order: tuple[str, ...] = DEFAULT_DOMAIN_ORDER, explicit_files: list[Path | str] | None = None, ) -> dict[str, list[TSQuestion]]: """ Load per-dataset question pools. If ``bank_dir`` is set, loads ``_questions.json`` for each domain in ``domain_order`` when that file exists, plus any ``*.json`` / ``*.jsonl`` in the directory that declare a ``dataset`` field per record (merged lists). If ``explicit_files`` is set, each file is loaded; records are grouped by ``dataset`` field (required for merged files). """ pools: dict[str, list[TSQuestion]] = {d: [] for d in domain_order} if explicit_files: for fp in explicit_files: path = Path(fp) items = load_jsonl_file(path) if path.suffix.lower() == ".jsonl" else load_json_file(path) for q in items: if q.dataset not in pools: pools[q.dataset] = [] pools[q.dataset].append(q) return pools if bank_dir is None: return pools root = Path(bank_dir) if not root.is_dir(): raise NotADirectoryError(f"question_bank_path must be a directory: {root}") # Per-dataset convention: PSML_questions.json etc. for domain in domain_order: candidates = [ root / f"{domain}_questions.json", root / f"{domain.lower()}_questions.json", ] for c in candidates: if c.is_file(): pools[domain].extend(load_json_file(c)) break # Any extra json/jsonl with dataset on each row (skip per-dataset files + manifests) for path in sorted(root.glob("*.json")) + sorted(root.glob("*.jsonl")): if path.name in ("manifest.json", "build_manifest.json"): continue if any(path.name == f"{d}_questions.json" for d in domain_order): continue if any(path.name == f"{d.lower()}_questions.json" for d in domain_order): continue items = load_jsonl_file(path) if path.suffix.lower() == ".jsonl" else load_json_file(path) for q in items: key = q.dataset if key not in pools: pools[key] = [] pools[key].append(q) return pools