TemporalBenchEnv / data /loaders.py
yashu2000's picture
Upload folder using huggingface_hub
d954568 verified
"""Load question banks from JSON or JSONL files."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from .question import TSQuestion
# Canonical domain keys used by EpisodeSampler (must match bank files or dataset field).
# This tuple is the default ``domain_order`` of ``load_question_banks``: every key
# gets a pool in the returned mapping (empty when nothing is loaded for it) and names
# the per-dataset bank file ``<key>_questions.json`` looked up inside the bank directory.
DEFAULT_DOMAIN_ORDER = ("PSML", "freshretailnet", "MIMIC", "causal_chambers")
def _parse_records(raw: Any) -> list[dict[str, Any]]:
if isinstance(raw, list):
return [x for x in raw if isinstance(x, dict)]
if isinstance(raw, dict) and "questions" in raw:
q = raw["questions"]
if isinstance(q, list):
return [x for x in q if isinstance(x, dict)]
raise ValueError("JSON root must be a list of objects or {\"questions\": [...]}")
def _record_to_question(obj: dict[str, Any]) -> TSQuestion:
    """Build a ``TSQuestion`` from one raw record via ``TSQuestion.model_validate``.

    Any error raised by ``model_validate`` propagates to the caller unchanged.
    """
    question = TSQuestion.model_validate(obj)
    return question
def load_json_file(path: Path) -> list[TSQuestion]:
    """Read one ``.json`` bank file and convert each record it holds.

    The file is decoded as UTF-8. Its root must be either an array of
    objects or an object of the form ``{"questions": [...]}``; the records
    are extracted by ``_parse_records`` and converted one by one with
    ``_record_to_question``.
    """
    with path.open(encoding="utf-8") as handle:
        payload = json.load(handle)
    return [_record_to_question(record) for record in _parse_records(payload)]
def load_jsonl_file(path: Path) -> list[TSQuestion]:
    """Load newline-delimited JSON; each line must be a full TSQuestion object.

    The file is decoded as UTF-8. Lines that are empty after stripping
    surrounding whitespace are skipped, but still count towards the line
    numbers reported in error messages.

    Args:
        path: Location of the ``.jsonl`` file.

    Returns:
        One question per non-blank line, in file order.

    Raises:
        ValueError: if a line is not valid JSON or does not decode to a JSON
            object. The message is prefixed with ``<path>:<line number>``.
    """
    out: list[TSQuestion] = []
    # Iterate the file object instead of ``read_text().splitlines()``:
    # ``str.splitlines`` also breaks on U+0085, U+2028 and U+2029, which may
    # appear unescaped inside JSON strings. Splitting there would cut a valid
    # record in two and skew every later line number. Text-mode iteration
    # only breaks on "\n", "\r" and "\r\n".
    with path.open(encoding="utf-8") as handle:
        for line_no, raw_line in enumerate(handle, start=1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as e:
                raise ValueError(f"{path}:{line_no}: invalid JSON: {e}") from e
            if not isinstance(obj, dict):
                raise ValueError(f"{path}:{line_no}: expected object per line")
            out.append(_record_to_question(obj))
    return out
def _load_bank_file(path: Path) -> list[TSQuestion]:
    """Load one bank file, choosing the parser from its case-insensitive suffix."""
    if path.suffix.lower() == ".jsonl":
        return load_jsonl_file(path)
    return load_json_file(path)


def _merge_by_dataset(
    pools: dict[str, list[TSQuestion]], questions: list[TSQuestion]
) -> None:
    """Append each question to the pool keyed by its ``dataset`` attribute."""
    for question in questions:
        pools.setdefault(question.dataset, []).append(question)


def load_question_banks(
    bank_dir: Path | str | None,
    *,
    domain_order: tuple[str, ...] = DEFAULT_DOMAIN_ORDER,
    explicit_files: list[Path | str] | None = None,
) -> dict[str, list[TSQuestion]]:
    """
    Load per-dataset question pools.

    The returned mapping always contains one (possibly empty) list for every
    key in ``domain_order``; further keys are added for any other ``dataset``
    value met in merged files.

    If ``explicit_files`` is non-empty it takes precedence and ``bank_dir`` is
    ignored: each file is loaded (``.jsonl`` as JSON Lines, anything else as
    JSON) and its records are grouped by their ``dataset`` field.

    Otherwise, if ``bank_dir`` is set, loads ``<Dataset>_questions.json`` (or
    its lower-cased spelling) for each domain in ``domain_order`` when that
    file exists, plus any other ``*.json`` / ``*.jsonl`` in the directory,
    whose records are grouped by their ``dataset`` field. ``manifest.json``
    and ``build_manifest.json`` are never read.

    Args:
        bank_dir: Directory holding the bank files, or ``None``.
        domain_order: Domain keys that seed the result and name the
            per-dataset files. Repeated keys are only processed once.
        explicit_files: Explicit list of bank files to load instead of
            scanning ``bank_dir``.

    Returns:
        Mapping from dataset key to the questions loaded for it.

    Raises:
        NotADirectoryError: if ``bank_dir`` is given but is not a directory.
    """
    # De-duplicate while keeping order so a repeated key cannot load the same
    # per-dataset file twice and duplicate its questions.
    domains = tuple(dict.fromkeys(domain_order))
    pools: dict[str, list[TSQuestion]] = {domain: [] for domain in domains}

    if explicit_files:
        for file_path in explicit_files:
            _merge_by_dataset(pools, _load_bank_file(Path(file_path)))
        return pools

    if bank_dir is None:
        return pools

    root = Path(bank_dir)
    if not root.is_dir():
        raise NotADirectoryError(f"question_bank_path must be a directory: {root}")

    # File names the merged-file scan below must not read: manifests, plus
    # both spellings of every per-dataset file (whether or not it was loaded).
    reserved = {"manifest.json", "build_manifest.json"}

    # Per-dataset convention: PSML_questions.json etc.
    for domain in domains:
        names = (f"{domain}_questions.json", f"{domain.lower()}_questions.json")
        reserved.update(names)
        for name in names:
            candidate = root / name
            if candidate.is_file():
                pools[domain].extend(load_json_file(candidate))
                break  # the first spelling found wins

    # Any extra json/jsonl with dataset on each row.
    for path in sorted(root.glob("*.json")) + sorted(root.glob("*.jsonl")):
        if path.name in reserved:
            continue
        if not path.is_file():
            # A directory whose name happens to match the pattern is not a bank.
            continue
        _merge_by_dataset(pools, _load_bank_file(path))
    return pools