Spaces:
Sleeping
Sleeping
| """Load question banks from JSON or JSONL files.""" | |
| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| from typing import Any | |
| from .question import TSQuestion | |
# Canonical domain keys used by EpisodeSampler (must match bank files or dataset field).
# Also the default ``domain_order`` of ``load_question_banks``: every key gets a pool
# in the returned mapping and serves as the ``<key>_questions.json`` filename stem
# looked up inside a bank directory.
DEFAULT_DOMAIN_ORDER = ("PSML", "freshretailnet", "MIMIC", "causal_chambers")
def _parse_records(raw: Any) -> list[dict[str, Any]]:
    """Extract the raw question records from a decoded JSON document.

    Two root shapes are accepted: a bare list, or a mapping whose
    ``"questions"`` entry is a list. Entries that are not dicts are dropped
    without raising.

    Raises:
        ValueError: If ``raw`` matches neither accepted shape.
    """
    # Pick the object that is expected to be the list of records.
    candidates = raw
    if isinstance(raw, dict) and "questions" in raw:
        candidates = raw["questions"]

    if not isinstance(candidates, list):
        raise ValueError("JSON root must be a list of objects or {\"questions\": [...]}")

    return [entry for entry in candidates if isinstance(entry, dict)]
def _record_to_question(obj: dict[str, Any]) -> TSQuestion:
    """Turn one raw record into a ``TSQuestion``.

    Delegates to ``TSQuestion.model_validate``; whatever that call raises
    propagates to the caller unchanged.
    """
    question = TSQuestion.model_validate(obj)
    return question
def load_json_file(path: Path) -> list[TSQuestion]:
    """Read one ``.json`` bank file and convert each record to a ``TSQuestion``.

    The file is decoded as UTF-8. Its root must be an array of objects or an
    object with a ``"questions"`` array (as enforced by ``_parse_records``).
    """
    text = path.read_text(encoding="utf-8")
    document = json.loads(text)

    questions: list[TSQuestion] = []
    for record in _parse_records(document):
        questions.append(_record_to_question(record))
    return questions
def load_jsonl_file(path: Path) -> list[TSQuestion]:
    """Load newline-delimited JSON; each non-blank line must be one JSON object.

    The file is decoded as UTF-8. Lines that are empty after stripping
    whitespace are skipped (they still count towards line numbers). Every
    remaining line is parsed as JSON and converted with
    ``_record_to_question``.

    Raises:
        ValueError: If a line is not valid JSON or does not decode to an
            object; the message is prefixed with ``<path>:<line_no>``.
    """
    out: list[TSQuestion] = []
    # Iterate the file object instead of ``read_text().splitlines()``:
    # ``str.splitlines`` also breaks on U+2028/U+2029, U+0085, VT, FF and the
    # FS/GS/RS separators. Some of these (e.g. U+2028) may legally appear raw
    # inside a JSON string, which would cut a valid record in two. Text-mode
    # iteration only breaks on "\n", "\r" and "\r\n".
    with path.open(encoding="utf-8") as fh:
        for line_no, line in enumerate(fh, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as e:
                raise ValueError(f"{path}:{line_no}: invalid JSON: {e}") from e
            if not isinstance(obj, dict):
                raise ValueError(f"{path}:{line_no}: expected object per line")
            out.append(_record_to_question(obj))
    return out
def load_question_banks(
    bank_dir: Path | str | None,
    *,
    domain_order: tuple[str, ...] = DEFAULT_DOMAIN_ORDER,
    explicit_files: list[Path | str] | None = None,
) -> dict[str, list[TSQuestion]]:
    """
    Build the per-dataset question pools.

    The result always has one (possibly empty) list per entry of
    ``domain_order``; further keys are added for any other ``dataset`` value
    found on loaded records.

    If ``explicit_files`` is non-empty, only those files are loaded and
    ``bank_dir`` is not consulted. Each record is grouped under its
    ``dataset`` field.

    Otherwise, if ``bank_dir`` is ``None`` the empty pools are returned. If it
    is set, two passes run over the directory:

    1. For each domain, the first existing file among
       ``<domain>_questions.json`` and ``<domain.lower()>_questions.json`` is
       loaded into that domain's pool.
    2. Every other ``*.json`` / ``*.jsonl`` file (excluding ``manifest.json``,
       ``build_manifest.json`` and the per-domain file names) is loaded and
       its records are grouped under their ``dataset`` field.

    Raises:
        NotADirectoryError: If ``bank_dir`` is given but is not a directory.
    """

    def read_bank(source: Path) -> list[TSQuestion]:
        # Choose the loader from the file suffix (case-insensitive).
        if source.suffix.lower() == ".jsonl":
            return load_jsonl_file(source)
        return load_json_file(source)

    pools: dict[str, list[TSQuestion]] = {}
    for name in domain_order:
        pools[name] = []

    def file_into_pools(questions: list[TSQuestion]) -> None:
        # Group records under their own ``dataset`` value, creating pools on demand.
        for question in questions:
            pools.setdefault(question.dataset, []).append(question)

    if explicit_files:
        for entry in explicit_files:
            file_into_pools(read_bank(Path(entry)))
        return pools

    if bank_dir is None:
        return pools

    root = Path(bank_dir)
    if not root.is_dir():
        raise NotADirectoryError(f"question_bank_path must be a directory: {root}")

    # Per-dataset convention: PSML_questions.json etc.; the first spelling found wins.
    for domain in domain_order:
        for stem in (domain, domain.lower()):
            bank_file = root / f"{stem}_questions.json"
            if bank_file.is_file():
                pools[domain].extend(load_json_file(bank_file))
                break

    # File names the second pass must not load: manifests and per-dataset files.
    reserved = {"manifest.json", "build_manifest.json"}
    for domain in domain_order:
        reserved.add(f"{domain}_questions.json")
        reserved.add(f"{domain.lower()}_questions.json")

    # Any extra json/jsonl with dataset on each row: all .json first, then .jsonl.
    extras = sorted(root.glob("*.json"))
    extras.extend(sorted(root.glob("*.jsonl")))
    for extra in extras:
        if extra.name in reserved:
            continue
        file_into_pools(read_bank(extra))

    return pools