File size: 4,122 Bytes
d954568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Load question banks from JSON or JSONL files."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any

from .question import TSQuestion

# Canonical domain keys used by EpisodeSampler (must match bank files or dataset field).
# This tuple is also the default ``domain_order`` of ``load_question_banks``: every key
# seeds an (initially empty) pool in the returned mapping, and ``<key>_questions.json``
# (or its lower-cased variant) is the per-dataset file looked up in the bank directory.
DEFAULT_DOMAIN_ORDER = ("PSML", "freshretailnet", "MIMIC", "causal_chambers")


def _parse_records(raw: Any) -> list[dict[str, Any]]:
    """Extract the list of record dicts from a decoded JSON document.

    Two root layouts are accepted: a bare list, or a dict whose
    ``"questions"`` key holds a list. Entries of that list which are not
    dicts are dropped without error.

    Raises:
        ValueError: if the root is neither of the accepted layouts (including
            a dict whose ``"questions"`` value is missing or not a list).
    """
    # Normalise both accepted layouts down to "the thing that should be a list".
    if isinstance(raw, dict):
        candidates = raw.get("questions")
    else:
        candidates = raw

    if not isinstance(candidates, list):
        raise ValueError("JSON root must be a list of objects or {\"questions\": [...]}")

    return [entry for entry in candidates if isinstance(entry, dict)]


def _record_to_question(obj: dict[str, Any]) -> TSQuestion:
    """Turn one raw record dict into a ``TSQuestion``.

    Delegates entirely to ``TSQuestion.model_validate``; whatever that call
    raises for a malformed record propagates to the caller unchanged.
    """
    question = TSQuestion.model_validate(obj)
    return question


def load_json_file(path: Path) -> list[TSQuestion]:
    """Read one ``.json`` question bank and return its questions.

    The file is decoded as UTF-8 and must contain either a JSON array of
    records or an object with a ``"questions"`` array (see
    ``_parse_records``). Each record is converted with
    ``_record_to_question``.

    Raises:
        json.JSONDecodeError: if the file is not valid JSON.
        ValueError: if the JSON root has an unsupported layout.
    """
    text = path.read_text(encoding="utf-8")
    document = json.loads(text)

    questions: list[TSQuestion] = []
    for record in _parse_records(document):
        questions.append(_record_to_question(record))
    return questions


def load_jsonl_file(path: Path) -> list[TSQuestion]:
    """Load newline-delimited JSON; each line must be a full TSQuestion object.

    The file is read as UTF-8. Lines that are empty after stripping
    surrounding whitespace are skipped; every other line must decode to a
    JSON object, which is converted with ``_record_to_question``.

    Args:
        path: Location of the ``.jsonl`` file.

    Returns:
        The questions in file order.

    Raises:
        ValueError: if a line is not valid JSON or does not decode to an
            object. The message is prefixed with ``<path>:<line number>``
            (1-based, blank lines included in the count).
    """
    out: list[TSQuestion] = []
    # Iterate the text-mode handle instead of ``read_text().splitlines()``.
    # ``str.splitlines`` also breaks on U+2028, U+2029, U+0085, \v, \f, ...,
    # and JSON permits several of those unescaped inside string values, so a
    # valid record would be cut in half. The file iterator only treats
    # \n, \r and \r\n as line endings, and it streams rather than loading
    # the whole file at once.
    with path.open(encoding="utf-8") as handle:
        for line_no, line in enumerate(handle, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as e:
                raise ValueError(f"{path}:{line_no}: invalid JSON: {e}") from e
            if not isinstance(obj, dict):
                raise ValueError(f"{path}:{line_no}: expected object per line")
            out.append(_record_to_question(obj))
    return out


def load_question_banks(
    bank_dir: Path | str | None,
    *,
    domain_order: tuple[str, ...] = DEFAULT_DOMAIN_ORDER,
    explicit_files: list[Path | str] | None = None,
) -> dict[str, list[TSQuestion]]:
    """
    Build a mapping from dataset name to its pool of questions.

    Every name in ``domain_order`` is present in the result, possibly with an
    empty list. Further keys are added for any other ``dataset`` value found
    on loaded records.

    If ``explicit_files`` is non-empty, only those files are read (``.jsonl``
    via ``load_jsonl_file``, anything else via ``load_json_file``), their
    records are grouped by each question's ``dataset`` attribute, and
    ``bank_dir`` is not consulted.

    Otherwise, if ``bank_dir`` is ``None`` the empty pools are returned.
    If it is set, two passes run over the directory:

    1. For each domain, ``<domain>_questions.json`` is tried, then the
       lower-cased ``<domain>`` variant; the first one that exists is loaded
       into that domain's pool.
    2. Every remaining ``*.json`` (sorted) and then ``*.jsonl`` (sorted) file
       is loaded and grouped by ``dataset``. ``manifest.json``,
       ``build_manifest.json`` and the per-domain file names from step 1 are
       skipped.

    Raises:
        NotADirectoryError: if ``bank_dir`` is given but is not a directory.
    """
    pools: dict[str, list[TSQuestion]] = {name: [] for name in domain_order}

    def _read_bank(source: Path) -> list[TSQuestion]:
        # Pick the loader from the (case-insensitive) file extension.
        if source.suffix.lower() == ".jsonl":
            return load_jsonl_file(source)
        return load_json_file(source)

    def _group_by_dataset(questions: list[TSQuestion]) -> None:
        # File each question under its own ``dataset``, creating pools on demand.
        for question in questions:
            pools.setdefault(question.dataset, []).append(question)

    if explicit_files:
        for entry in explicit_files:
            _group_by_dataset(_read_bank(Path(entry)))
        return pools

    if bank_dir is None:
        return pools

    root = Path(bank_dir)
    if not root.is_dir():
        raise NotADirectoryError(f"question_bank_path must be a directory: {root}")

    # Pass 1 - per-dataset convention: PSML_questions.json etc.
    for name in domain_order:
        exact = root / f"{name}_questions.json"
        lowered = root / f"{name.lower()}_questions.json"
        if exact.is_file():
            chosen = exact
        elif lowered.is_file():
            chosen = lowered
        else:
            continue
        pools[name].extend(load_json_file(chosen))

    # File names pass 2 must ignore: manifests plus every per-dataset name from pass 1.
    reserved = {"manifest.json", "build_manifest.json"}
    for name in domain_order:
        reserved.add(f"{name}_questions.json")
        reserved.add(f"{name.lower()}_questions.json")

    # Pass 2 - any extra json/jsonl whose rows carry their own dataset.
    extras = [*sorted(root.glob("*.json")), *sorted(root.glob("*.jsonl"))]
    for source in extras:
        if source.name in reserved:
            continue
        _group_by_dataset(_read_bank(source))

    return pools