yashu2000 commited on
Commit
d954568
·
verified ·
1 Parent(s): 5e7b72a

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/banks/PSML_questions.json filter=lfs diff=lfs merge=lfs -text
37
+ data/banks/causal_chambers_questions.json filter=lfs diff=lfs merge=lfs -text
38
+ data/banks/freshretailnet_questions.json filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Multi-stage build using openenv-base
8
+ # This Dockerfile is flexible and works for both:
9
+ # - In-repo environments (with local OpenEnv sources)
10
+ # - Standalone environments (with openenv from PyPI/Git)
11
+ # The build script (openenv build) handles context detection and sets appropriate build args.
12
+
13
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
14
+ FROM ${BASE_IMAGE} AS builder
15
+
16
+ WORKDIR /app
17
+
18
+ # Ensure git is available (required for installing dependencies from VCS)
19
+ RUN apt-get update && \
20
+ apt-get install -y --no-install-recommends git && \
21
+ rm -rf /var/lib/apt/lists/*
22
+
23
+ # Build argument to control whether we're building standalone or in-repo
24
+ ARG BUILD_MODE=in-repo
25
+ ARG ENV_NAME=TemporalBenchEnv
26
+
27
+ # Copy environment code (always at root of build context)
28
+ COPY . /app/env
29
+
30
+ # For in-repo builds, openenv is already vendored in the build context
31
+ # For standalone builds, openenv will be installed via pyproject.toml
32
+ WORKDIR /app/env
33
+
34
+ # Ensure uv is available (for local builds where base image lacks it)
35
+ RUN if ! command -v uv >/dev/null 2>&1; then \
36
+ curl -LsSf https://astral.sh/uv/install.sh | sh && \
37
+ mv /root/.local/bin/uv /usr/local/bin/uv && \
38
+ mv /root/.local/bin/uvx /usr/local/bin/uvx; \
39
+ fi
40
+
41
+ # Install dependencies using uv sync
42
+ # If uv.lock exists, use it; otherwise resolve on the fly
43
+ RUN --mount=type=cache,target=/root/.cache/uv \
44
+ if [ -f uv.lock ]; then \
45
+ uv sync --frozen --no-install-project --no-editable; \
46
+ else \
47
+ uv sync --no-install-project --no-editable; \
48
+ fi
49
+
50
+ RUN --mount=type=cache,target=/root/.cache/uv \
51
+ if [ -f uv.lock ]; then \
52
+ uv sync --frozen --no-editable; \
53
+ else \
54
+ uv sync --no-editable; \
55
+ fi
56
+
57
+ # Final runtime stage
58
+ FROM ${BASE_IMAGE}
59
+
60
+ WORKDIR /app
61
+
62
+ # Copy the virtual environment from builder
63
+ COPY --from=builder /app/env/.venv /app/.venv
64
+
65
+ # Copy the environment code
66
+ COPY --from=builder /app/env /app/env
67
+
68
+ # Copy the question bank
69
+ # COPY TS-benchmark/artifacts/temporal_bench_openenv_banks /app/env/data/banks
70
+ ENV TEMPORALBENCH_QUESTION_BANK_DIR=/app/env/data/banks
71
+
72
+ # Set PATH to use the virtual environment
73
+ ENV PATH="/app/.venv/bin:$PATH"
74
+
75
+ # Set PYTHONPATH so imports work correctly
76
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
77
+
78
+ # Health check
79
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
80
+ CMD curl -f http://localhost:8000/health || exit 1
81
+
82
+ # Run the FastAPI server
83
+ # The module path is constructed to work with the /app/env structure
84
+ ENV ENABLE_WEB_INTERFACE=true
85
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
README.md CHANGED
@@ -1,10 +1,201 @@
1
  ---
2
- title: TemporalBenchEnv
3
- emoji: 📊
4
- colorFrom: red
5
- colorTo: yellow
6
  sdk: docker
7
  pinned: false
 
 
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: TemporalBenchEnv MCQ Server
3
+ emoji: 🥁
4
+ colorFrom: yellow
5
+ colorTo: indigo
6
  sdk: docker
7
  pinned: false
8
+ app_port: 8000
9
+ base_path: /web
10
+ tags:
11
+ - openenv
12
  ---
13
 
14
+ # TemporalBenchEnv
15
+
16
+ OpenEnv environment for **multi-step multiple-choice** time-series reasoning. Each episode samples nine questions from pre-built JSON banks (per-dataset files or merged JSONL in `TSQuestion` shape). Rewards combine per-step correctness and an episode bonus (see `env/reward.py`).
17
+
18
+ ## Question bank layout
19
+
20
+ Point the server at a directory containing `PSML_questions.json`, `freshretailnet_questions.json`, `MIMIC_questions.json`, and `causal_chambers_questions.json` (each file is a JSON array of `TSQuestion` records), or set **`TEMPORALBENCH_QUESTION_BANK_DIR`** to that path. If unset, the server uses `tests/fixtures/banks` when present (for local smoke runs).
21
+
22
+ Each record must include at least: `question_id`, `dataset`, `task_type` (`T1U` | `T3` | `T2_MCQ`), `prompt`, `options` (length ≥ 2), `answer`, plus optional `family`, `capability_tags`, `difficulty`, `metadata`.
23
+
24
+ ## Quick Start
25
+
26
+ Use the typed client (`TemporalBenchEnvClient`; alias `TemporalbenchenvEnv`):
27
+
28
+ ```python
29
+ from client import TemporalBenchAction, TemporalBenchEnvClient
30
+
31
+ try:
32
+ env = TemporalBenchEnvClient.from_docker_image("TemporalBenchEnv-env:latest")
33
+ out = env.reset()
34
+ while not out.done:
35
+ q = out.observation
36
+ # Agent picks q.options[i] or equivalent label string
37
+ out = env.step(TemporalBenchAction(answer=q.options[0]))
38
+ finally:
39
+ env.close()
40
+ ```
41
+
42
+ `TemporalBenchEnvClient.from_docker_image()` handles:
43
+ - Starting the Docker container
44
+ - Waiting for the server to be ready
45
+ - Connecting to the environment
46
+ - Container cleanup when you call `close()`
47
+
48
+ ## Building the Docker Image
49
+
50
+ Before using the environment, you need to build the Docker image:
51
+
52
+ ```bash
53
+ # From project root
54
+ docker build -t TemporalBenchEnv-env:latest -f server/Dockerfile .
55
+ ```
56
+
57
+ ## Deploying to Hugging Face Spaces
58
+
59
+ You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
60
+
61
+ ```bash
62
+ # From the environment directory (where openenv.yaml is located)
63
+ openenv push
64
+
65
+ # Or specify options
66
+ openenv push --namespace my-org --private
67
+ ```
68
+
69
+ The `openenv push` command will:
70
+ 1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
71
+ 2. Prepare a custom build for Hugging Face Docker space (enables web interface)
72
+ 3. Upload to Hugging Face (ensuring you're logged in)
73
+
74
+ ### Prerequisites
75
+
76
+ - Authenticate with Hugging Face: The command will prompt for login if not already authenticated
77
+
78
+ ### Options
79
+
80
+ - `--directory`, `-d`: Directory containing the OpenEnv environment (defaults to current directory)
81
+ - `--repo-id`, `-r`: Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
82
+ - `--base-image`, `-b`: Base Docker image to use (overrides Dockerfile FROM)
83
+ - `--private`: Deploy the space as private (default: public)
84
+
85
+ ### Examples
86
+
87
+ ```bash
88
+ # Push to your personal namespace (defaults to username/env-name from openenv.yaml)
89
+ openenv push
90
+
91
+ # Push to a specific repository
92
+ openenv push --repo-id my-org/my-env
93
+
94
+ # Push with a custom base image
95
+ openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
96
+
97
+ # Push as a private space
98
+ openenv push --private
99
+
100
+ # Combine options
101
+ openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
102
+ ```
103
+
104
+ After deployment, your space will be available at:
105
+ `https://huggingface.co/spaces/<repo-id>`
106
+
107
+ The deployed space includes:
108
+ - **Web Interface** at `/web` - Interactive UI for exploring the environment
109
+ - **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
110
+ - **Health Check** at `/health` - Container health monitoring
111
+ - **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
112
+
113
+ ## Environment Details
114
+
115
+ ### Action (`TemporalBenchAction`)
116
+ - `answer` (str) — MCQ label (must match ground truth after optional normalization)
117
+ - `confidence`, `reasoning` — optional
118
+
119
+ ### Observation (`TemporalBenchObservation`)
120
+ - `question`, `options`, `task_type`, `dataset`, `history`, `accuracy_so_far`
121
+ - `step_idx`, `steps_remaining`, `max_steps`, `done`, `reward`, `metadata`
122
+
123
+ ### Reward
124
+ - Per step: `alpha * correctness` (correctness 0 or 1).
125
+ - On the final step, adds episode bonus: `lambda_ep * (total_correct / N) * coverage_multiplier` (1.0 if every dataset in the episode has at least one correct answer, else 0.8).
126
+
127
+ ## Advanced Usage
128
+
129
+ ### Connecting to an Existing Server
130
+
131
+ If you already have a TemporalBenchEnv server running, connect with:
132
+
133
+ ```python
134
+ from client import TemporalBenchAction, TemporalBenchEnvClient
135
+
136
+ with TemporalBenchEnvClient(base_url="http://localhost:8000") as env:
137
+ r = env.reset()
138
+ r = env.step(TemporalBenchAction(answer=r.observation.options[0]))
139
+ ```
140
+
141
+ Note: `close()` does not stop a remote server you attached to with `base_url=...`.
142
+
143
+ ### Using the Context Manager
144
+
145
+ The client supports context manager usage for automatic connection management:
146
+
147
+ ```python
148
+ from client import TemporalBenchAction, TemporalBenchEnvClient
149
+
150
+ with TemporalBenchEnvClient(base_url="http://localhost:8000") as env:
151
+ result = env.reset()
152
+ while not result.done:
153
+ ans = result.observation.options[0]
154
+ result = env.step(TemporalBenchAction(answer=ans))
155
+ ```
156
+
157
+ The client uses WebSocket connections for:
158
+ - **Lower latency**: No HTTP connection overhead per request
159
+ - **Persistent session**: Server maintains your environment state
160
+ - **Efficient for episodes**: Better for many sequential steps
161
+
162
+ ### Concurrent WebSocket Sessions
163
+
164
+ The server uses **factory mode** (`create_app(_env_factory, ...)`) so each WebSocket session gets a fresh `TemporalBenchEnvironment`. Tune `max_concurrent_envs` in `server/app.py` as needed.
165
+
166
+ ## Development & Testing
167
+
168
+ ### Direct environment testing
169
+
170
+ ```bash
171
+ uv sync --extra dev
172
+ uv run pytest tests/
173
+ ```
174
+
175
+ ### Running Locally
176
+
177
+ Run the server locally for development:
178
+
179
+ ```bash
180
+ uvicorn server.app:app --reload
181
+ ```
182
+
183
+ ## Project Structure
184
+
185
+ ```
186
+ TemporalBenchEnv/
187
+ ├── .dockerignore # Docker build exclusions
188
+ ├── __init__.py # Module exports
189
+ ├── README.md # This file
190
+ ├── openenv.yaml # OpenEnv manifest
191
+ ├── pyproject.toml # Project metadata and dependencies
192
+ ├── uv.lock # Locked dependencies (generated)
193
+ ├── client.py # TemporalBenchEnvClient (alias TemporalbenchenvEnv)
194
+ ├── models.py # Action / observation / state re-exports
195
+ ├── env/ # Environment, sampler, grading, rewards
196
+ ├── data/ # TSQuestion schema + JSON/JSONL loaders
197
+ └── server/
198
+ ├── __init__.py # Server module exports
199
+ ├── app.py # FastAPI application (HTTP + WebSocket endpoints)
200
+ └── Dockerfile # Container image definition
201
+ ```
__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """TemporalBenchEnv — MCQ time-series reasoning for OpenEnv."""
8
+
9
+ from client import TemporalBenchEnvClient, TemporalbenchenvEnv
10
+ from models import (
11
+ TemporalBenchAction,
12
+ TemporalBenchObservation,
13
+ TemporalBenchState,
14
+ )
15
+
16
+ __all__ = [
17
+ "TemporalBenchAction",
18
+ "TemporalBenchObservation",
19
+ "TemporalBenchState",
20
+ "TemporalBenchEnvClient",
21
+ "TemporalbenchenvEnv",
22
+ ]
client.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Typed OpenEnv client for TemporalBenchEnv."""
8
+
9
+ from typing import Any, Dict
10
+
11
+ from openenv.core.client_types import StepResult
12
+ from openenv.core.env_client import EnvClient
13
+
14
+ try:
15
+ from env.models import (
16
+ TemporalBenchAction,
17
+ TemporalBenchObservation,
18
+ TemporalBenchState,
19
+ )
20
+ except ImportError:
21
+ from TemporalBenchEnv.env.models import (
22
+ TemporalBenchAction,
23
+ TemporalBenchObservation,
24
+ TemporalBenchState,
25
+ )
26
+
27
+
28
+ class TemporalBenchEnvClient(
29
+ EnvClient[
30
+ TemporalBenchAction,
31
+ TemporalBenchObservation,
32
+ TemporalBenchState,
33
+ ]
34
+ ):
35
+ """WebSocket client for TemporalBench MCQ episodes."""
36
+
37
+ def _step_payload(self, action: TemporalBenchAction) -> Dict[str, Any]:
38
+ payload: Dict[str, Any] = {"answer": action.answer}
39
+ if action.confidence is not None:
40
+ payload["confidence"] = action.confidence
41
+ if action.reasoning is not None:
42
+ payload["reasoning"] = action.reasoning
43
+ return payload
44
+
45
+ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[TemporalBenchObservation]:
46
+ obs_data = payload.get("observation")
47
+ if not isinstance(obs_data, dict):
48
+ obs_data = payload if isinstance(payload, dict) else {}
49
+
50
+ done = payload.get("done", obs_data.get("done", False))
51
+ reward = payload.get("reward", obs_data.get("reward"))
52
+
53
+ observation = TemporalBenchObservation(
54
+ step_idx=int(obs_data.get("step_idx", 0)),
55
+ steps_remaining=int(obs_data.get("steps_remaining", 0)),
56
+ max_steps=int(obs_data.get("max_steps", 9)),
57
+ question=str(obs_data.get("question", "")),
58
+ options=list(obs_data.get("options", [])),
59
+ task_type=str(obs_data.get("task_type", "")),
60
+ dataset=str(obs_data.get("dataset", "")),
61
+ history=list(obs_data.get("history", [])),
62
+ accuracy_so_far=float(obs_data.get("accuracy_so_far", 0.0)),
63
+ done=done,
64
+ reward=reward,
65
+ metadata=obs_data.get("metadata", {}),
66
+ )
67
+ return StepResult(observation=observation, reward=reward, done=done)
68
+
69
+ def _parse_state(self, payload: Dict[str, Any]) -> TemporalBenchState:
70
+ state_data = payload.get("state")
71
+ if not isinstance(state_data, dict):
72
+ state_data = payload if isinstance(payload, dict) else {}
73
+
74
+ return TemporalBenchState(
75
+ episode_id=state_data.get("episode_id"),
76
+ step_count=int(state_data.get("step_count", 0)),
77
+ total_correct=int(state_data.get("total_correct", 0)),
78
+ total_questions=int(state_data.get("total_questions", 9)),
79
+ current_accuracy=float(state_data.get("current_accuracy", 0.0)),
80
+ primary_domain=str(state_data.get("primary_domain", "PSML")),
81
+ per_task_type_accuracy=dict(state_data.get("per_task_type_accuracy", {})),
82
+ total_reward=float(state_data.get("total_reward", 0.0)),
83
+ )
84
+
85
+
86
+ TemporalbenchenvEnv = TemporalBenchEnvClient
data/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """Question bank types and loaders for TemporalBenchEnv."""
2
+
3
+ from .question import TSQuestion
4
+ from .loaders import load_question_banks
5
+
6
+ __all__ = ["TSQuestion", "load_question_banks"]
data/banks/MIMIC_questions.json ADDED
The diff for this file is too large to render. See raw diff
 
data/banks/PSML_questions.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fbe5a4b7ef698a2cf38cc1bad5e0a49e8ea2e8d5e694fed063992fb9b85202c
3
+ size 31377743
data/banks/causal_chambers_questions.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8fecebc4f5dc8db88b83825bd35d0a9b003f16c8df43521d616e928104481163
3
+ size 30021480
data/banks/freshretailnet_questions.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:475410d2787c20bf3d908e4f4b67febd78ff9da3f4e2bf993ef8a7c9af31431a
3
+ size 31221095
data/banks/manifest.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "source_file": "/Users/yashaswisharma/Desktop/Academia Career/usc-gradschool/Research/AgentX/TS-Agents-Melady/TS-benchmark/task_merged_dev_with_labels_tiers.jsonl",
3
+ "input_rows": 764,
4
+ "total_ts_questions": 2775,
5
+ "per_dataset": {
6
+ "PSML": 750,
7
+ "freshretailnet": 616,
8
+ "MIMIC": 709,
9
+ "causal_chambers": 700
10
+ },
11
+ "openenv_loader": "Set TEMPORALBENCH_QUESTION_BANK_DIR to this directory."
12
+ }
data/loaders.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Load question banks from JSON or JSONL files."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ from .question import TSQuestion
10
+
11
+ # Canonical domain keys used by EpisodeSampler (must match bank files or dataset field)
12
+ DEFAULT_DOMAIN_ORDER = ("PSML", "freshretailnet", "MIMIC", "causal_chambers")
13
+
14
+
15
+ def _parse_records(raw: Any) -> list[dict[str, Any]]:
16
+ if isinstance(raw, list):
17
+ return [x for x in raw if isinstance(x, dict)]
18
+ if isinstance(raw, dict) and "questions" in raw:
19
+ q = raw["questions"]
20
+ if isinstance(q, list):
21
+ return [x for x in q if isinstance(x, dict)]
22
+ raise ValueError("JSON root must be a list of objects or {\"questions\": [...]}")
23
+
24
+
25
+ def _record_to_question(obj: dict[str, Any]) -> TSQuestion:
26
+ return TSQuestion.model_validate(obj)
27
+
28
+
29
+ def load_json_file(path: Path) -> list[TSQuestion]:
30
+ """Load a single .json file (array or {\"questions\": [...]})."""
31
+ raw = json.loads(path.read_text(encoding="utf-8"))
32
+ records = _parse_records(raw)
33
+ return [_record_to_question(r) for r in records]
34
+
35
+
36
+ def load_jsonl_file(path: Path) -> list[TSQuestion]:
37
+ """Load newline-delimited JSON; each line must be a full TSQuestion object."""
38
+ out: list[TSQuestion] = []
39
+ for line_no, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1):
40
+ line = line.strip()
41
+ if not line:
42
+ continue
43
+ try:
44
+ obj = json.loads(line)
45
+ except json.JSONDecodeError as e:
46
+ raise ValueError(f"{path}:{line_no}: invalid JSON: {e}") from e
47
+ if not isinstance(obj, dict):
48
+ raise ValueError(f"{path}:{line_no}: expected object per line")
49
+ out.append(_record_to_question(obj))
50
+ return out
51
+
52
+
53
+ def load_question_banks(
54
+ bank_dir: Path | str | None,
55
+ *,
56
+ domain_order: tuple[str, ...] = DEFAULT_DOMAIN_ORDER,
57
+ explicit_files: list[Path | str] | None = None,
58
+ ) -> dict[str, list[TSQuestion]]:
59
+ """
60
+ Load per-dataset question pools.
61
+
62
+ If ``bank_dir`` is set, loads ``<Dataset>_questions.json`` for each domain in
63
+ ``domain_order`` when that file exists, plus any ``*.json`` / ``*.jsonl`` in
64
+ the directory that declare a ``dataset`` field per record (merged lists).
65
+
66
+ If ``explicit_files`` is set, each file is loaded; records are grouped by
67
+ ``dataset`` field (required for merged files).
68
+ """
69
+ pools: dict[str, list[TSQuestion]] = {d: [] for d in domain_order}
70
+
71
+ if explicit_files:
72
+ for fp in explicit_files:
73
+ path = Path(fp)
74
+ items = load_jsonl_file(path) if path.suffix.lower() == ".jsonl" else load_json_file(path)
75
+ for q in items:
76
+ if q.dataset not in pools:
77
+ pools[q.dataset] = []
78
+ pools[q.dataset].append(q)
79
+ return pools
80
+
81
+ if bank_dir is None:
82
+ return pools
83
+
84
+ root = Path(bank_dir)
85
+ if not root.is_dir():
86
+ raise NotADirectoryError(f"question_bank_path must be a directory: {root}")
87
+
88
+ # Per-dataset convention: PSML_questions.json etc.
89
+ for domain in domain_order:
90
+ candidates = [
91
+ root / f"{domain}_questions.json",
92
+ root / f"{domain.lower()}_questions.json",
93
+ ]
94
+ for c in candidates:
95
+ if c.is_file():
96
+ pools[domain].extend(load_json_file(c))
97
+ break
98
+
99
+ # Any extra json/jsonl with dataset on each row (skip per-dataset files + manifests)
100
+ for path in sorted(root.glob("*.json")) + sorted(root.glob("*.jsonl")):
101
+ if path.name in ("manifest.json", "build_manifest.json"):
102
+ continue
103
+ if any(path.name == f"{d}_questions.json" for d in domain_order):
104
+ continue
105
+ if any(path.name == f"{d.lower()}_questions.json" for d in domain_order):
106
+ continue
107
+ items = load_jsonl_file(path) if path.suffix.lower() == ".jsonl" else load_json_file(path)
108
+ for q in items:
109
+ key = q.dataset
110
+ if key not in pools:
111
+ pools[key] = []
112
+ pools[key].append(q)
113
+
114
+ return pools
data/question.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Canonical question record for TemporalBench MCQ episodes."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+
10
+ class TSQuestion(BaseModel):
11
+ """One MCQ item from a pre-built question bank."""
12
+
13
+ question_id: str = Field(..., description="Unique identifier")
14
+ dataset: str = Field(
15
+ ...,
16
+ description="Source dataset, e.g. PSML, freshretailnet, MIMIC, causal_chambers",
17
+ )
18
+ task_type: str = Field(..., description="T1U | T3 | T2_MCQ")
19
+ family: str | None = Field(default=None, description="T3 family, e.g. S1:A")
20
+ prompt: str = Field(..., description="Full question text with context")
21
+ options: list[str] = Field(..., min_length=2, description="Answer choices")
22
+ answer: str = Field(..., description="Ground-truth label (matches one option after normalize)")
23
+ capability_tags: list[str] = Field(default_factory=list)
24
+ difficulty: str | None = Field(default=None)
25
+ metadata: dict[str, Any] = Field(default_factory=dict)
env/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """TemporalBenchEnv core logic."""
2
+
3
+ from .config import EnvConfig
4
+ from .temporal_bench_env import TemporalBenchEnvironment
5
+
6
+ __all__ = ["EnvConfig", "TemporalBenchEnvironment"]
env/config.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TemporalBenchEnv configuration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+
7
+
8
+ DEFAULT_DOMAINS: tuple[str, ...] = ("PSML", "freshretailnet", "MIMIC", "causal_chambers")
9
+
10
+
11
+ @dataclass
12
+ class EnvConfig:
13
+ """Episode structure, rewards, and data paths."""
14
+
15
+ num_questions: int = 9
16
+ primary_count: int = 6
17
+ secondary_count: int = 3
18
+ primary_domain: str = "PSML"
19
+ allowed_task_types: tuple[str, ...] = ("T1U", "T3", "T2_MCQ")
20
+ alpha: float = 1.0
21
+ lambda_ep: float = 0.5
22
+ partial_credit_t1u: bool = False
23
+ case_insensitive_match: bool = False
24
+ curriculum_stage: int = 3
25
+ question_bank_path: str | None = None
26
+ seed: int | None = None
27
+ forecasting_enabled: bool = False
28
+ all_domains: tuple[str, ...] = field(default_factory=lambda: DEFAULT_DOMAINS)
env/episode_sampler.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sample 9-question episodes from per-dataset pools."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections import defaultdict
6
+ from typing import TYPE_CHECKING
7
+
8
+ import numpy as np
9
+
10
+ from .config import DEFAULT_DOMAINS, EnvConfig
11
+
12
+ if TYPE_CHECKING:
13
+ from data.question import TSQuestion
14
+
15
+
16
+ def _curriculum_task_types(config: EnvConfig) -> set[str]:
17
+ stage = int(config.curriculum_stage)
18
+ if stage <= 1:
19
+ allowed = {"T1U"}
20
+ elif stage == 2:
21
+ allowed = {"T1U", "T3"}
22
+ else:
23
+ allowed = {"T1U", "T3", "T2_MCQ"}
24
+ return allowed.intersection(set(config.allowed_task_types))
25
+
26
+
27
+ def _filter_pool(questions: list[TSQuestion], config: EnvConfig) -> list[TSQuestion]:
28
+ allowed_types = _curriculum_task_types(config)
29
+ return [q for q in questions if q.task_type in allowed_types]
30
+
31
+
32
+ class EpisodeSampler:
33
+ """Builds stratified multi-domain MCQ episodes."""
34
+
35
+ def __init__(
36
+ self,
37
+ pools: dict[str, list[TSQuestion]],
38
+ config: EnvConfig,
39
+ rng: np.random.Generator,
40
+ ):
41
+ self._config = config
42
+ self._rng = rng
43
+ self._pools: dict[str, list[TSQuestion]] = {}
44
+ for domain, qs in pools.items():
45
+ filtered = _filter_pool(list(qs), config)
46
+ if filtered:
47
+ self._pools[domain] = filtered
48
+
49
+ def sample_episode(self) -> list[TSQuestion]:
50
+ """Return shuffled list of ``num_questions`` questions (default 9)."""
51
+ cfg = self._config
52
+ n = cfg.num_questions
53
+ primary_n = cfg.primary_count
54
+ domains = list(cfg.all_domains)
55
+ primary = cfg.primary_domain
56
+ if primary not in domains:
57
+ raise ValueError(f"primary_domain {primary!r} not in all_domains {domains}")
58
+
59
+ others = [d for d in domains if d != primary]
60
+ if len(others) != 3:
61
+ raise ValueError("Expected exactly 3 non-primary domains")
62
+
63
+ for d in domains:
64
+ if d not in self._pools or not self._pools[d]:
65
+ raise ValueError(f"No questions available for dataset {d!r} after filtering")
66
+
67
+ primary_pool = list(self._pools[primary])
68
+ chosen: list[TSQuestion] = []
69
+
70
+ chosen.extend(self._sample_primary_stratified(primary_pool, primary_n))
71
+
72
+ for d in others:
73
+ idx = int(self._rng.integers(0, len(self._pools[d])))
74
+ chosen.append(self._pools[d][idx])
75
+
76
+ if len(chosen) != n:
77
+ raise ValueError(f"Expected {n} questions, got {len(chosen)}")
78
+
79
+ order = self._rng.permutation(n)
80
+ return [chosen[i] for i in order]
81
+
82
+ def _sample_primary_stratified(self, pool: list[TSQuestion], k: int) -> list[TSQuestion]:
83
+ """Prefer spreading T3 across families; fill with uniform random."""
84
+ rng = self._rng
85
+ if len(pool) < k:
86
+ raise ValueError(f"Primary pool too small: need {k}, have {len(pool)}")
87
+
88
+ # Partition by (task_type, family or "")
89
+ buckets: dict[tuple[str, str], list[TSQuestion]] = defaultdict(list)
90
+ for q in pool:
91
+ fam = q.family or ""
92
+ buckets[(q.task_type, fam)].append(q)
93
+
94
+ # Round-robin order across T3 families first, then other keys
95
+ t3_families = sorted({fam for (tt, fam) in buckets if tt == "T3"})
96
+ rr_keys: list[tuple[str, str]] = [("T3", f) for f in t3_families]
97
+ other_keys = [key for key in buckets if key not in rr_keys]
98
+ rng.shuffle(other_keys)
99
+ key_order = rr_keys + other_keys
100
+
101
+ picks: list[TSQuestion] = []
102
+ while len(picks) < k and key_order:
103
+ progressed = False
104
+ for key in list(key_order):
105
+ if len(picks) >= k:
106
+ break
107
+ b = buckets.get(key, [])
108
+ if not b:
109
+ key_order.remove(key)
110
+ continue
111
+ idx = int(rng.integers(0, len(b)))
112
+ picks.append(b.pop(idx))
113
+ progressed = True
114
+ if not progressed:
115
+ break
116
+
117
+ # Uniform fill from remaining pool elements
118
+ remainder = [q for lst in buckets.values() for q in lst]
119
+ rng.shuffle(remainder)
120
+ for q in remainder:
121
+ if len(picks) >= k:
122
+ break
123
+ picks.append(q)
124
+
125
+ if len(picks) < k:
126
+ chosen_ids = {p.question_id for p in picks}
127
+ pool_rest = [q for q in pool if q.question_id not in chosen_ids]
128
+ need = k - len(picks)
129
+ if len(pool_rest) < need:
130
+ raise ValueError("Primary pool could not supply enough unique questions")
131
+ extra_idx = rng.choice(len(pool_rest), size=need, replace=False)
132
+ for i in extra_idx:
133
+ picks.append(pool_rest[int(i)])
134
+
135
+ return picks[:k]
136
+
137
+
138
+ def default_domains_for_config(config: EnvConfig) -> tuple[str, ...]:
139
+ """Domains tuple from config (falls back to module default)."""
140
+ return tuple(config.all_domains) if config.all_domains else DEFAULT_DOMAINS
env/grading.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MCQ answer grading."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ if TYPE_CHECKING:
8
+ from data.question import TSQuestion
9
+
10
+ from .config import EnvConfig
11
+
12
+
13
+ def _normalize(s: str, *, lower: bool) -> str:
14
+ s = str(s).strip()
15
+ return s.lower() if lower else s
16
+
17
+
18
+ def grade_answer(
19
+ submitted: str,
20
+ question: TSQuestion,
21
+ config: EnvConfig,
22
+ ) -> tuple[bool, float]:
23
+ """
24
+ Returns (is_fully_correct, partial_score in [0,1]).
25
+
26
+ Binary match to ``question.answer``, or equality to an option whose
27
+ normalized text equals the normalized ground-truth answer.
28
+ ``partial_credit_t1u`` is reserved for future multi-part T1U grading.
29
+ """
30
+ lower = config.case_insensitive_match
31
+ truth = _normalize(question.answer, lower=lower)
32
+ pred = _normalize(submitted, lower=lower)
33
+
34
+ if pred == truth:
35
+ return True, 1.0
36
+ for opt in question.options:
37
+ if pred == _normalize(opt, lower=lower) and _normalize(opt, lower=lower) == truth:
38
+ return True, 1.0
39
+ return False, 0.0
env/models.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenEnv Pydantic models for TemporalBenchEnv."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ from pydantic import BaseModel, ConfigDict, Field
8
+
9
+ try:
10
+ from openenv.core.env_server.types import Action as _ActionBase
11
+ from openenv.core.env_server.types import Observation as _ObservationBase
12
+ from openenv.core.env_server.types import State as _StateBase
13
+ except ImportError:
14
+ _ActionBase = BaseModel
15
+ _ObservationBase = BaseModel
16
+ _StateBase = BaseModel
17
+
18
+
19
+ class TemporalBenchAction(_ActionBase):
20
+ """Agent submits an MCQ answer (optional confidence / reasoning)."""
21
+
22
+ if _ActionBase is BaseModel:
23
+ model_config = ConfigDict(extra="forbid")
24
+ metadata: dict[str, Any] = Field(default_factory=dict)
25
+
26
+ answer: str = Field(..., description="MCQ answer label matching an option")
27
+ confidence: float | None = Field(default=None, ge=0.0, le=1.0)
28
+ reasoning: str | None = Field(default=None, description="Optional chain-of-thought")
29
+
30
+
31
+ class TemporalBenchObservation(_ObservationBase):
32
+ """Current question and progress."""
33
+
34
+ if _ObservationBase is BaseModel:
35
+ model_config = ConfigDict(extra="forbid")
36
+ done: bool = Field(default=False)
37
+ reward: float | None = Field(default=None)
38
+ metadata: dict[str, Any] = Field(default_factory=dict)
39
+
40
+ step_idx: int = Field(..., ge=0)
41
+ steps_remaining: int = Field(..., ge=0)
42
+ max_steps: int = Field(default=9, ge=1)
43
+ question: str = Field(..., description="Current MCQ prompt")
44
+ options: list[str] = Field(..., description="Answer choices")
45
+ task_type: str = Field(..., description="T1U | T3 | T2_MCQ")
46
+ dataset: str = Field(..., description="Source dataset")
47
+ history: list[dict[str, Any]] = Field(default_factory=list)
48
+ accuracy_so_far: float = Field(default=0.0, ge=0.0, le=1.0)
49
+
50
+
51
+ class TemporalBenchState(_StateBase):
52
+ """Serializable environment state."""
53
+
54
+ if _StateBase is BaseModel:
55
+ model_config = ConfigDict(extra="allow")
56
+ episode_id: str | None = Field(default=None)
57
+ step_count: int = Field(default=0, ge=0)
58
+
59
+ total_correct: int = Field(default=0, ge=0)
60
+ total_questions: int = Field(default=9, ge=0)
61
+ current_accuracy: float = Field(default=0.0, ge=0.0, le=1.0)
62
+ primary_domain: str = Field(default="PSML")
63
+ per_task_type_accuracy: dict[str, float] = Field(default_factory=dict)
64
+ total_reward: float = Field(default=0.0)
env/reward.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reward helpers for TemporalBenchEnv."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+
9
+
10
+ def compute_mcq_reward(correctness_score: float, alpha: float = 1.0) -> float:
11
+ """Per-step reward: alpha * correctness (0..1)."""
12
+ return float(alpha) * float(correctness_score)
13
+
14
+
15
+ def compute_forecasting_reward(
16
+ predicted: np.ndarray,
17
+ ground_truth: np.ndarray,
18
+ metric: str = "mse",
19
+ **kwargs: Any,
20
+ ) -> float:
21
+ del predicted, ground_truth, metric, kwargs
22
+ raise NotImplementedError("Forecasting reward not yet implemented")
23
+
24
+
25
+ def compute_episode_bonus(
26
+ total_correct: int,
27
+ num_questions: int,
28
+ domain_coverage: dict[str, int],
29
+ *,
30
+ all_domains: tuple[str, ...],
31
+ lambda_ep: float = 0.5,
32
+ ) -> float:
33
+ """
34
+ Episode bonus: lambda_ep * episode_accuracy * domain_coverage_bonus.
35
+
36
+ ``domain_coverage`` maps dataset name -> count of *correct* answers in the
37
+ episode for that dataset. Bonus uses 1.0 if every domain in ``all_domains``
38
+ has at least one correct answer, else 0.8.
39
+ """
40
+ accuracy = total_correct / num_questions if num_questions > 0 else 0.0
41
+ counts = [int(domain_coverage.get(d, 0)) for d in all_domains]
42
+ all_covered = all(c > 0 for c in counts)
43
+ coverage_bonus = 1.0 if all_covered else 0.8
44
+ return float(lambda_ep) * float(accuracy) * coverage_bonus
env/temporal_bench_env.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Core OpenEnv environment for TemporalBench MCQ episodes."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import uuid
6
+ from collections import defaultdict
7
+ from dataclasses import replace
8
+ from typing import Any, Optional
9
+
10
+ import numpy as np
11
+
12
+ from data.loaders import load_question_banks
13
+ from data.question import TSQuestion
14
+
15
+ from .config import EnvConfig
16
+ from .episode_sampler import EpisodeSampler
17
+ from .grading import grade_answer
18
+ from .models import TemporalBenchAction, TemporalBenchObservation, TemporalBenchState
19
+ from .reward import compute_episode_bonus, compute_mcq_reward
20
+
21
+ try:
22
+ from openenv.core.env_server.interfaces import Environment
23
+ except ImportError:
24
+ from abc import ABC, abstractmethod
25
+ from typing import Generic, TypeVar
26
+
27
+ ActT = TypeVar("ActT")
28
+ ObsT = TypeVar("ObsT")
29
+ StateT = TypeVar("StateT")
30
+
31
+ class Environment(ABC, Generic[ActT, ObsT, StateT]):
32
+ @abstractmethod
33
+ def reset(self, seed=None, episode_id=None, **kwargs): ...
34
+
35
+ @abstractmethod
36
+ def step(self, action, timeout_s=None, **kwargs): ...
37
+
38
+ @property
39
+ @abstractmethod
40
+ def state(self): ...
41
+
42
+
43
+ class TemporalBenchEnvironment(
44
+ Environment[TemporalBenchAction, TemporalBenchObservation, TemporalBenchState]
45
+ ):
46
+ """Multi-step MCQ environment over a pre-built TemporalBench question bank."""
47
+
48
+ SUPPORTS_CONCURRENT_SESSIONS: bool = True
49
+
50
+ def __init__(self, config: Optional[EnvConfig] = None, **kwargs: Any):
51
+ super().__init__(**kwargs)
52
+ self._config = config or EnvConfig()
53
+ seed = self._config.seed
54
+ self._rng = np.random.default_rng(seed)
55
+
56
+ self._banks = load_question_banks(self._config.question_bank_path)
57
+ self._sampler = EpisodeSampler(self._banks, self._config, self._rng)
58
+
59
+ self._episode_id: Optional[str] = None
60
+ self._questions: list[TSQuestion] = []
61
+ self._answered: int = 0
62
+ self._history: list[dict[str, Any]] = []
63
+ self._done: bool = False
64
+ self._total_correct: int = 0
65
+ self._total_reward: float = 0.0
66
+ self._domain_correct: dict[str, int] = defaultdict(int)
67
+ self._task_correct: dict[str, int] = defaultdict(int)
68
+ self._task_total: dict[str, int] = defaultdict(int)
69
+ self._last_metadata: dict[str, Any] = {}
70
+
71
+ def _accuracy_so_far(self) -> float:
72
+ if self._answered == 0:
73
+ return 0.0
74
+ return self._total_correct / self._answered
75
+
76
+ def _per_task_accuracy(self) -> dict[str, float]:
77
+ out: dict[str, float] = {}
78
+ for k, tot in self._task_total.items():
79
+ out[k] = (self._task_correct[k] / tot) if tot else 0.0
80
+ return out
81
+
82
+ def _build_observation(
83
+ self,
84
+ *,
85
+ reward: float | None,
86
+ done: bool,
87
+ ) -> TemporalBenchObservation:
88
+ n = self._config.num_questions
89
+ if done or self._answered >= n:
90
+ return TemporalBenchObservation(
91
+ step_idx=self._answered,
92
+ steps_remaining=0,
93
+ max_steps=n,
94
+ question="",
95
+ options=[],
96
+ task_type="",
97
+ dataset="",
98
+ history=list(self._history),
99
+ accuracy_so_far=self._accuracy_so_far(),
100
+ done=True,
101
+ reward=reward,
102
+ metadata=dict(self._last_metadata),
103
+ )
104
+
105
+ q = self._questions[self._answered]
106
+ steps_remaining = n - self._answered
107
+ return TemporalBenchObservation(
108
+ step_idx=self._answered,
109
+ steps_remaining=steps_remaining,
110
+ max_steps=n,
111
+ question=q.prompt,
112
+ options=list(q.options),
113
+ task_type=q.task_type,
114
+ dataset=q.dataset,
115
+ history=list(self._history),
116
+ accuracy_so_far=self._accuracy_so_far(),
117
+ done=False,
118
+ reward=reward,
119
+ metadata=dict(self._last_metadata),
120
+ )
121
+
122
+ def reset(
123
+ self,
124
+ seed: Optional[int] = None,
125
+ episode_id: Optional[str] = None,
126
+ **kwargs: Any,
127
+ ) -> TemporalBenchObservation:
128
+ curriculum_kw = kwargs.pop("curriculum_stage", None)
129
+
130
+ if seed is not None:
131
+ self._rng = np.random.default_rng(seed)
132
+
133
+ cfg = self._config
134
+ if curriculum_kw is not None:
135
+ cfg = replace(self._config, curriculum_stage=int(curriculum_kw))
136
+
137
+ self._sampler = EpisodeSampler(self._banks, cfg, self._rng)
138
+
139
+ self._episode_id = episode_id or str(uuid.uuid4())
140
+ self._questions = self._sampler.sample_episode()
141
+ self._answered = 0
142
+ self._history = []
143
+ self._done = False
144
+ self._total_correct = 0
145
+ self._total_reward = 0.0
146
+ self._domain_correct = defaultdict(int)
147
+ self._task_correct = defaultdict(int)
148
+ self._task_total = defaultdict(int)
149
+ self._last_metadata = {}
150
+
151
+ return self._build_observation(reward=0.0, done=False)
152
+
153
+ def step(
154
+ self,
155
+ action: TemporalBenchAction,
156
+ timeout_s: Optional[float] = None,
157
+ **kwargs: Any,
158
+ ) -> TemporalBenchObservation:
159
+ del timeout_s, kwargs
160
+ if self._done:
161
+ self._last_metadata = {"info": "Episode already done."}
162
+ return self._build_observation(reward=0.0, done=True)
163
+
164
+ self._last_metadata = {}
165
+
166
+ n = self._config.num_questions
167
+ if self._answered >= n:
168
+ self._done = True
169
+ self._last_metadata = {"info": "Episode already complete."}
170
+ return self._build_observation(reward=0.0, done=True)
171
+
172
+ q = self._questions[self._answered]
173
+ if not str(action.answer).strip():
174
+ self._last_metadata = {"error": "answer must be a non-empty string."}
175
+ return self._build_observation(reward=0.0, done=False)
176
+
177
+ fully_correct, score = grade_answer(action.answer, q, self._config)
178
+ r_step = compute_mcq_reward(score, alpha=self._config.alpha)
179
+
180
+ self._history.append(
181
+ {
182
+ "question_id": q.question_id,
183
+ "dataset": q.dataset,
184
+ "task_type": q.task_type,
185
+ "submitted": action.answer,
186
+ "correct": fully_correct,
187
+ "reward": r_step,
188
+ }
189
+ )
190
+
191
+ self._task_total[q.task_type] += 1
192
+ if fully_correct:
193
+ self._total_correct += 1
194
+ self._domain_correct[q.dataset] += 1
195
+ self._task_correct[q.task_type] += 1
196
+
197
+ self._answered += 1
198
+ total_reward_this_step = r_step
199
+
200
+ if self._answered >= n:
201
+ bonus = compute_episode_bonus(
202
+ self._total_correct,
203
+ n,
204
+ dict(self._domain_correct),
205
+ all_domains=tuple(self._config.all_domains),
206
+ lambda_ep=self._config.lambda_ep,
207
+ )
208
+ total_reward_this_step = r_step + bonus
209
+ self._done = True
210
+ self._last_metadata = {
211
+ "episode_bonus": bonus,
212
+ "domain_correct_counts": dict(self._domain_correct),
213
+ }
214
+
215
+ self._total_reward += total_reward_this_step
216
+
217
+ return self._build_observation(
218
+ reward=total_reward_this_step,
219
+ done=self._done,
220
+ )
221
+
222
+ @property
223
+ def state(self) -> TemporalBenchState:
224
+ return TemporalBenchState(
225
+ episode_id=self._episode_id,
226
+ step_count=self._answered,
227
+ total_correct=self._total_correct,
228
+ total_questions=self._config.num_questions,
229
+ current_accuracy=self._accuracy_so_far(),
230
+ primary_domain=self._config.primary_domain,
231
+ per_task_type_accuracy=self._per_task_accuracy(),
232
+ total_reward=self._total_reward,
233
+ )
models.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """OpenEnv models at package root for schema validation and server imports."""
8
+
9
+ try:
10
+ from env.models import (
11
+ TemporalBenchAction,
12
+ TemporalBenchObservation,
13
+ TemporalBenchState,
14
+ )
15
+ except ImportError:
16
+ from .env.models import (
17
+ TemporalBenchAction,
18
+ TemporalBenchObservation,
19
+ TemporalBenchState,
20
+ )
21
+
22
+ __all__ = [
23
+ "TemporalBenchAction",
24
+ "TemporalBenchObservation",
25
+ "TemporalBenchState",
26
+ ]
openenv.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: TemporalBenchEnv
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8000
7
+
openenv_TemporalBenchEnv.egg-info/PKG-INFO ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: openenv-TemporalBenchEnv
3
+ Version: 0.1.0
4
+ Summary: TemporalBench MCQ environment for OpenEnv
5
+ Requires-Python: >=3.10
6
+ Requires-Dist: openenv-core[core]>=0.2.0
7
+ Requires-Dist: numpy>=1.24.0
8
+ Requires-Dist: pydantic>=2.0.0
9
+ Provides-Extra: dev
10
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
11
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
openenv_TemporalBenchEnv.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ client.py
3
+ models.py
4
+ pyproject.toml
5
+ data/__init__.py
6
+ data/loaders.py
7
+ data/question.py
8
+ env/__init__.py
9
+ env/config.py
10
+ env/episode_sampler.py
11
+ env/grading.py
12
+ env/models.py
13
+ env/reward.py
14
+ env/temporal_bench_env.py
15
+ openenv_TemporalBenchEnv.egg-info/PKG-INFO
16
+ openenv_TemporalBenchEnv.egg-info/SOURCES.txt
17
+ openenv_TemporalBenchEnv.egg-info/dependency_links.txt
18
+ openenv_TemporalBenchEnv.egg-info/entry_points.txt
19
+ openenv_TemporalBenchEnv.egg-info/requires.txt
20
+ openenv_TemporalBenchEnv.egg-info/top_level.txt
21
+ server/__init__.py
22
+ server/app.py
23
+ tests/test_env.py
24
+ tests/test_loaders.py
25
+ tests/test_reward.py
26
+ tests/test_sampler.py
openenv_TemporalBenchEnv.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
openenv_TemporalBenchEnv.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [console_scripts]
2
+ server = server.app:main
openenv_TemporalBenchEnv.egg-info/requires.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ openenv-core[core]>=0.2.0
2
+ numpy>=1.24.0
3
+ pydantic>=2.0.0
4
+
5
+ [dev]
6
+ pytest>=8.0.0
7
+ pytest-cov>=4.0.0
openenv_TemporalBenchEnv.egg-info/top_level.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ client
2
+ data
3
+ env
4
+ models
5
+ server
pyproject.toml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ [build-system]
8
+ requires = ["setuptools>=61.0", "wheel"]
9
+ build-backend = "setuptools.build_meta"
10
+
11
+ [project]
12
+ name = "openenv-TemporalBenchEnv"
13
+ version = "0.1.0"
14
+ description = "TemporalBench MCQ environment for OpenEnv"
15
+ requires-python = ">=3.10"
16
+ dependencies = [
17
+ "openenv-core[core]>=0.2.0",
18
+ "numpy>=1.24.0",
19
+ "pydantic>=2.0.0",
20
+ ]
21
+
22
+ [project.optional-dependencies]
23
+ dev = [
24
+ "pytest>=8.0.0",
25
+ "pytest-cov>=4.0.0",
26
+ ]
27
+
28
+ [project.scripts]
29
+ server = "server.app:main"
30
+
31
+ [tool.setuptools]
32
+ include-package-data = true
33
+ py-modules = ["client", "models"]
34
+
35
+ [tool.setuptools.packages.find]
36
+ where = ["."]
37
+ include = ["env*", "data*", "server*"]
38
+
39
+ [tool.pytest.ini_options]
40
+ pythonpath = ["."]
41
+ testpaths = ["tests"]
server/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """TemporalBenchEnv environment server components."""
8
+
9
+ from env.temporal_bench_env import TemporalBenchEnvironment
10
+
11
+ __all__ = ["TemporalBenchEnvironment"]
server/app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """FastAPI application for TemporalBenchEnv."""
8
+
9
+ import os
10
+ from pathlib import Path
11
+
12
+ try:
13
+ from env.config import EnvConfig
14
+ from env.temporal_bench_env import TemporalBenchEnvironment
15
+ from models import TemporalBenchAction, TemporalBenchObservation
16
+ except ImportError:
17
+ from ..env.config import EnvConfig
18
+ from ..env.temporal_bench_env import TemporalBenchEnvironment
19
+ from ..models import TemporalBenchAction, TemporalBenchObservation
20
+
21
+ try:
22
+ from openenv.core.env_server import create_app
23
+ except ImportError:
24
+ create_app = None # type: ignore
25
+
26
+
27
+ def _env_factory():
28
+ """Create a fresh environment instance per WebSocket session."""
29
+ bank_dir = os.environ.get("TEMPORALBENCH_QUESTION_BANK_DIR")
30
+ if not bank_dir:
31
+ default = Path(__file__).resolve().parents[1] / "tests" / "fixtures" / "banks"
32
+ if default.is_dir():
33
+ bank_dir = str(default)
34
+ cfg = EnvConfig(question_bank_path=bank_dir) if bank_dir else EnvConfig()
35
+ return TemporalBenchEnvironment(config=cfg)
36
+
37
+
38
+ if create_app is not None:
39
+ app = create_app(
40
+ _env_factory,
41
+ TemporalBenchAction,
42
+ TemporalBenchObservation,
43
+ env_name="temporal-bench-env",
44
+ max_concurrent_envs=64,
45
+ )
46
+ else:
47
+ from fastapi import FastAPI
48
+
49
+ app = FastAPI(title="temporal-bench-env")
50
+ app.get("/health")(lambda: {"status": "ok"})
51
+
52
+
53
+ def main(host: str | None = None, port: int | None = None) -> None:
54
+ """
55
+ Entry point for `uv run server` and OpenEnv multi-mode validation.
56
+
57
+ OpenEnv's validator does a naive substring check for ``main()`` in this
58
+ file, so the ``if __name__ == "__main__"`` block must call ``main()`` with
59
+ no arguments; CLI flags are parsed here via ``parse_known_args``.
60
+ """
61
+ import argparse
62
+
63
+ import uvicorn
64
+
65
+ if host is None or port is None:
66
+ parser = argparse.ArgumentParser()
67
+ parser.add_argument("--host", type=str, default="0.0.0.0")
68
+ parser.add_argument("--port", type=int, default=8000)
69
+ ns, _ = parser.parse_known_args()
70
+ if host is None:
71
+ host = ns.host
72
+ if port is None:
73
+ port = ns.port
74
+
75
+ uvicorn.run(app, host=host, port=port)
76
+
77
+
78
+ if __name__ == "__main__":
79
+ main()
server/requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ openenv[core]>=0.2.0
2
+ fastapi>=0.115.0
3
+ uvicorn>=0.24.0
4
+
5
+
6
+
tests/conftest.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import pytest
4
+
5
+ FIXTURE_BANK_DIR = Path(__file__).resolve().parent / "fixtures" / "banks"
6
+
7
+
8
+ @pytest.fixture
9
+ def fixture_bank_dir() -> Path:
10
+ assert FIXTURE_BANK_DIR.is_dir()
11
+ return FIXTURE_BANK_DIR
tests/fixtures/banks/MIMIC_questions.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [
2
+ {"question_id": "mim-1", "dataset": "MIMIC", "task_type": "T1U", "family": null, "prompt": "M1", "options": ["0", "1"], "answer": "0", "capability_tags": [], "difficulty": null, "metadata": {}},
3
+ {"question_id": "mim-2", "dataset": "MIMIC", "task_type": "T2_MCQ", "family": null, "prompt": "M2", "options": ["0", "1"], "answer": "1", "capability_tags": [], "difficulty": null, "metadata": {}}
4
+ ]
tests/fixtures/banks/PSML_questions.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {"question_id": "psml-t3-s1a-1", "dataset": "PSML", "task_type": "T3", "family": "S1:A", "prompt": "Q1", "options": ["a", "b"], "answer": "a", "capability_tags": ["C1"], "difficulty": "easy", "metadata": {}},
3
+ {"question_id": "psml-t3-s2b-1", "dataset": "PSML", "task_type": "T3", "family": "S2:B", "prompt": "Q2", "options": ["a", "b"], "answer": "b", "capability_tags": ["C2"], "difficulty": "easy", "metadata": {}},
4
+ {"question_id": "psml-t3-s3c-1", "dataset": "PSML", "task_type": "T3", "family": "S3:C", "prompt": "Q3", "options": ["x", "y"], "answer": "x", "capability_tags": [], "difficulty": null, "metadata": {}},
5
+ {"question_id": "psml-t1u-1", "dataset": "PSML", "task_type": "T1U", "family": null, "prompt": "Q4 trend?", "options": ["upward", "downward"], "answer": "upward", "capability_tags": [], "difficulty": null, "metadata": {}},
6
+ {"question_id": "psml-t2-1", "dataset": "PSML", "task_type": "T2_MCQ", "family": null, "prompt": "Q5", "options": ["m1", "m2"], "answer": "m1", "capability_tags": [], "difficulty": null, "metadata": {}},
7
+ {"question_id": "psml-t3-s4d-1", "dataset": "PSML", "task_type": "T3", "family": "S4:D", "prompt": "Q6", "options": ["p", "q"], "answer": "p", "capability_tags": [], "difficulty": null, "metadata": {}},
8
+ {"question_id": "psml-t1u-2", "dataset": "PSML", "task_type": "T1U", "family": null, "prompt": "Q7", "options": ["yes", "no"], "answer": "yes", "capability_tags": [], "difficulty": null, "metadata": {}},
9
+ {"question_id": "psml-t1u-3", "dataset": "PSML", "task_type": "T1U", "family": null, "prompt": "Q7b", "options": ["yes", "no"], "answer": "no", "capability_tags": [], "difficulty": null, "metadata": {}},
10
+ {"question_id": "psml-t1u-4", "dataset": "PSML", "task_type": "T1U", "family": null, "prompt": "Q7c", "options": ["u", "d"], "answer": "u", "capability_tags": [], "difficulty": null, "metadata": {}},
11
+ {"question_id": "psml-t1u-5", "dataset": "PSML", "task_type": "T1U", "family": null, "prompt": "Q7d", "options": ["u", "d"], "answer": "d", "capability_tags": [], "difficulty": null, "metadata": {}},
12
+ {"question_id": "psml-t1u-6", "dataset": "PSML", "task_type": "T1U", "family": null, "prompt": "Q7e", "options": ["e", "f"], "answer": "e", "capability_tags": [], "difficulty": null, "metadata": {}},
13
+ {"question_id": "psml-t2-2", "dataset": "PSML", "task_type": "T2_MCQ", "family": null, "prompt": "Q8", "options": ["h", "l"], "answer": "h", "capability_tags": [], "difficulty": null, "metadata": {}}
14
+ ]
tests/fixtures/banks/causal_chambers_questions.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [
2
+ {"question_id": "cc-1", "dataset": "causal_chambers", "task_type": "T3", "family": "S5:E", "prompt": "CC1", "options": ["low", "high"], "answer": "low", "capability_tags": [], "difficulty": null, "metadata": {}},
3
+ {"question_id": "cc-2", "dataset": "causal_chambers", "task_type": "T1U", "family": null, "prompt": "CC2", "options": ["low", "high"], "answer": "high", "capability_tags": [], "difficulty": null, "metadata": {}}
4
+ ]
tests/fixtures/banks/freshretailnet_questions.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [
2
+ {"question_id": "frn-1", "dataset": "freshretailnet", "task_type": "T1U", "family": null, "prompt": "FRN1", "options": ["A", "B"], "answer": "A", "capability_tags": [], "difficulty": null, "metadata": {}},
3
+ {"question_id": "frn-2", "dataset": "freshretailnet", "task_type": "T3", "family": "S1:A", "prompt": "FRN2", "options": ["A", "B"], "answer": "B", "capability_tags": [], "difficulty": null, "metadata": {}}
4
+ ]
tests/test_env.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import pytest
4
+
5
+ from env.config import EnvConfig
6
+ from env.temporal_bench_env import TemporalBenchEnvironment
7
+ from models import TemporalBenchAction
8
+
9
+
10
+ def _bank_dir() -> Path:
11
+ return Path(__file__).resolve().parent / "fixtures" / "banks"
12
+
13
+
14
+ def test_full_episode_all_correct():
15
+ cfg = EnvConfig(question_bank_path=str(_bank_dir()), seed=123, lambda_ep=0.5, alpha=1.0)
16
+ env = TemporalBenchEnvironment(config=cfg)
17
+ obs0 = env.reset(seed=123)
18
+ assert not obs0.done
19
+ assert obs0.step_idx == 0
20
+ assert len(obs0.options) >= 2
21
+
22
+ total_r = 0.0
23
+ obs = obs0
24
+ for i in range(cfg.num_questions):
25
+ cur = env._questions[env._answered] # noqa: SLF001
26
+ obs = env.step(TemporalBenchAction(answer=cur.answer))
27
+ total_r += obs.reward or 0.0
28
+ if i < cfg.num_questions - 1:
29
+ assert not obs.done
30
+ assert obs.step_idx == i + 1
31
+ else:
32
+ assert obs.done
33
+ assert obs.step_idx == cfg.num_questions
34
+
35
+ st = env.state
36
+ assert st.total_correct == cfg.num_questions
37
+ assert st.step_count == cfg.num_questions
38
+ assert st.total_reward == pytest.approx(total_r)
39
+
40
+
41
+ def test_episode_with_one_wrong_answer():
42
+ cfg = EnvConfig(question_bank_path=str(_bank_dir()), seed=0, lambda_ep=0.5, alpha=1.0)
43
+ env = TemporalBenchEnvironment(config=cfg)
44
+ env.reset(seed=0)
45
+ for i in range(cfg.num_questions):
46
+ cur = env._questions[env._answered] # noqa: SLF001
47
+ wrong = next(o for o in cur.options if o != cur.answer)
48
+ act = TemporalBenchAction(answer=wrong if i == 0 else cur.answer)
49
+ obs = env.step(act)
50
+ if i == cfg.num_questions - 1:
51
+ assert obs.done
52
+ meta = obs.metadata or {}
53
+ assert "episode_bonus" in meta
54
+ bonus = float(meta["episode_bonus"])
55
+ lo = 0.5 * (8 / 9) * 0.8
56
+ hi = 0.5 * (8 / 9) * 1.0
57
+ assert lo - 1e-9 <= bonus <= hi + 1e-9
58
+ assert env.state.total_correct == 8
59
+
60
+
61
+ def test_double_step_after_done():
62
+ cfg = EnvConfig(question_bank_path=str(_bank_dir()), seed=1)
63
+ env = TemporalBenchEnvironment(config=cfg)
64
+ env.reset(seed=1)
65
+ for _ in range(cfg.num_questions):
66
+ cur = env._questions[env._answered] # noqa: SLF001
67
+ env.step(TemporalBenchAction(answer=cur.answer))
68
+ obs = env.step(TemporalBenchAction(answer="x"))
69
+ assert obs.done
70
+
71
+
72
+ def test_empty_answer_no_advance():
73
+ cfg = EnvConfig(question_bank_path=str(_bank_dir()), seed=2)
74
+ env = TemporalBenchEnvironment(config=cfg)
75
+ env.reset(seed=2)
76
+ before = env._answered # noqa: SLF001
77
+ obs = env.step(TemporalBenchAction(answer=" "))
78
+ assert env._answered == before # noqa: SLF001
79
+ assert "error" in (obs.metadata or {})
tests/test_loaders.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import pytest
4
+
5
+ from data.loaders import load_json_file, load_question_banks
6
+ from data.question import TSQuestion
7
+
8
+
9
+ def test_load_json_file_roundtrip(fixture_bank_dir: Path):
10
+ path = fixture_bank_dir / "PSML_questions.json"
11
+ qs = load_json_file(path)
12
+ assert len(qs) >= 6
13
+ assert all(isinstance(q, TSQuestion) for q in qs)
14
+ assert qs[0].dataset == "PSML"
15
+
16
+
17
+ def test_load_question_banks_directory(fixture_bank_dir: Path):
18
+ pools = load_question_banks(str(fixture_bank_dir))
19
+ for d in ("PSML", "freshretailnet", "MIMIC", "causal_chambers"):
20
+ assert d in pools
21
+ assert len(pools[d]) >= 1
22
+
23
+
24
+ def test_load_question_banks_rejects_non_dir():
25
+ with pytest.raises(NotADirectoryError):
26
+ load_question_banks("/nonexistent/path/that/is/not/a/dir")
tests/test_reward.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from env.reward import compute_episode_bonus, compute_forecasting_reward, compute_mcq_reward
4
+
5
+
6
+ def test_compute_mcq_reward():
7
+ assert compute_mcq_reward(1.0, alpha=2.0) == 2.0
8
+ assert compute_mcq_reward(0.0) == 0.0
9
+
10
+
11
+ def test_compute_episode_bonus_coverage():
12
+ domains = ("PSML", "freshretailnet", "MIMIC", "causal_chambers")
13
+ cov = {"PSML": 1, "freshretailnet": 0, "MIMIC": 1, "causal_chambers": 1}
14
+ b = compute_episode_bonus(3, 9, cov, all_domains=domains, lambda_ep=0.5)
15
+ assert b == 0.5 * (3 / 9) * 0.8
16
+
17
+ cov2 = {"PSML": 1, "freshretailnet": 1, "MIMIC": 1, "causal_chambers": 1}
18
+ b2 = compute_episode_bonus(4, 9, cov2, all_domains=domains, lambda_ep=0.5)
19
+ assert b2 == 0.5 * (4 / 9) * 1.0
20
+
21
+
22
+ def test_forecasting_reward_stub():
23
+ import numpy as np
24
+
25
+ with pytest.raises(NotImplementedError):
26
+ compute_forecasting_reward(np.array([1.0]), np.array([1.0]))
tests/test_sampler.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from data.loaders import load_question_banks
4
+ from env.config import EnvConfig
5
+ from env.episode_sampler import EpisodeSampler
6
+
7
+
8
+ def test_sample_episode_shape_and_domains(fixture_bank_dir):
9
+ banks = load_question_banks(str(fixture_bank_dir))
10
+ cfg = EnvConfig(
11
+ question_bank_path=str(fixture_bank_dir),
12
+ primary_domain="PSML",
13
+ seed=42,
14
+ )
15
+ rng = np.random.default_rng(42)
16
+ sampler = EpisodeSampler(banks, cfg, rng)
17
+ ep = sampler.sample_episode()
18
+ assert len(ep) == cfg.num_questions
19
+ datasets = {q.dataset for q in ep}
20
+ assert datasets == set(cfg.all_domains)
21
+
22
+
23
+ def test_curriculum_stage_filters_task_types(fixture_bank_dir):
24
+ banks = load_question_banks(str(fixture_bank_dir))
25
+ cfg = EnvConfig(
26
+ question_bank_path=str(fixture_bank_dir),
27
+ curriculum_stage=1,
28
+ seed=0,
29
+ )
30
+ rng = np.random.default_rng(0)
31
+ sampler = EpisodeSampler(banks, cfg, rng)
32
+ for _ in range(5):
33
+ ep = sampler.sample_episode()
34
+ assert all(q.task_type == "T1U" for q in ep)
uv.lock ADDED
The diff for this file is too large to render. See raw diff