Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- .gitattributes +3 -0
- Dockerfile +85 -0
- README.md +196 -5
- __init__.py +22 -0
- client.py +86 -0
- data/__init__.py +6 -0
- data/banks/MIMIC_questions.json +0 -0
- data/banks/PSML_questions.json +3 -0
- data/banks/causal_chambers_questions.json +3 -0
- data/banks/freshretailnet_questions.json +3 -0
- data/banks/manifest.json +12 -0
- data/loaders.py +114 -0
- data/question.py +25 -0
- env/__init__.py +6 -0
- env/config.py +28 -0
- env/episode_sampler.py +140 -0
- env/grading.py +39 -0
- env/models.py +64 -0
- env/reward.py +44 -0
- env/temporal_bench_env.py +233 -0
- models.py +26 -0
- openenv.yaml +7 -0
- openenv_TemporalBenchEnv.egg-info/PKG-INFO +11 -0
- openenv_TemporalBenchEnv.egg-info/SOURCES.txt +26 -0
- openenv_TemporalBenchEnv.egg-info/dependency_links.txt +1 -0
- openenv_TemporalBenchEnv.egg-info/entry_points.txt +2 -0
- openenv_TemporalBenchEnv.egg-info/requires.txt +7 -0
- openenv_TemporalBenchEnv.egg-info/top_level.txt +5 -0
- pyproject.toml +41 -0
- server/__init__.py +11 -0
- server/app.py +79 -0
- server/requirements.txt +6 -0
- tests/conftest.py +11 -0
- tests/fixtures/banks/MIMIC_questions.json +4 -0
- tests/fixtures/banks/PSML_questions.json +14 -0
- tests/fixtures/banks/causal_chambers_questions.json +4 -0
- tests/fixtures/banks/freshretailnet_questions.json +4 -0
- tests/test_env.py +79 -0
- tests/test_loaders.py +26 -0
- tests/test_reward.py +26 -0
- tests/test_sampler.py +34 -0
- uv.lock +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
data/banks/PSML_questions.json filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
data/banks/causal_chambers_questions.json filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
data/banks/freshretailnet_questions.json filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Multi-stage build using openenv-base
|
| 8 |
+
# This Dockerfile is flexible and works for both:
|
| 9 |
+
# - In-repo environments (with local OpenEnv sources)
|
| 10 |
+
# - Standalone environments (with openenv from PyPI/Git)
|
| 11 |
+
# The build script (openenv build) handles context detection and sets appropriate build args.
|
| 12 |
+
|
| 13 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 14 |
+
FROM ${BASE_IMAGE} AS builder
|
| 15 |
+
|
| 16 |
+
WORKDIR /app
|
| 17 |
+
|
| 18 |
+
# Ensure git is available (required for installing dependencies from VCS)
|
| 19 |
+
RUN apt-get update && \
|
| 20 |
+
apt-get install -y --no-install-recommends git && \
|
| 21 |
+
rm -rf /var/lib/apt/lists/*
|
| 22 |
+
|
| 23 |
+
# Build argument to control whether we're building standalone or in-repo
|
| 24 |
+
ARG BUILD_MODE=in-repo
|
| 25 |
+
ARG ENV_NAME=TemporalBenchEnv
|
| 26 |
+
|
| 27 |
+
# Copy environment code (always at root of build context)
|
| 28 |
+
COPY . /app/env
|
| 29 |
+
|
| 30 |
+
# For in-repo builds, openenv is already vendored in the build context
|
| 31 |
+
# For standalone builds, openenv will be installed via pyproject.toml
|
| 32 |
+
WORKDIR /app/env
|
| 33 |
+
|
| 34 |
+
# Ensure uv is available (for local builds where base image lacks it)
|
| 35 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 36 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 37 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 38 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
# Install dependencies using uv sync
|
| 42 |
+
# If uv.lock exists, use it; otherwise resolve on the fly
|
| 43 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 44 |
+
if [ -f uv.lock ]; then \
|
| 45 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 46 |
+
else \
|
| 47 |
+
uv sync --no-install-project --no-editable; \
|
| 48 |
+
fi
|
| 49 |
+
|
| 50 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 51 |
+
if [ -f uv.lock ]; then \
|
| 52 |
+
uv sync --frozen --no-editable; \
|
| 53 |
+
else \
|
| 54 |
+
uv sync --no-editable; \
|
| 55 |
+
fi
|
| 56 |
+
|
| 57 |
+
# Final runtime stage
|
| 58 |
+
FROM ${BASE_IMAGE}
|
| 59 |
+
|
| 60 |
+
WORKDIR /app
|
| 61 |
+
|
| 62 |
+
# Copy the virtual environment from builder
|
| 63 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 64 |
+
|
| 65 |
+
# Copy the environment code
|
| 66 |
+
COPY --from=builder /app/env /app/env
|
| 67 |
+
|
| 68 |
+
# Copy the question bank
|
| 69 |
+
# COPY TS-benchmark/artifacts/temporal_bench_openenv_banks /app/env/data/banks
|
| 70 |
+
ENV TEMPORALBENCH_QUESTION_BANK_DIR=/app/env/data/banks
|
| 71 |
+
|
| 72 |
+
# Set PATH to use the virtual environment
|
| 73 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 74 |
+
|
| 75 |
+
# Set PYTHONPATH so imports work correctly
|
| 76 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 77 |
+
|
| 78 |
+
# Health check
|
| 79 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 80 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 81 |
+
|
| 82 |
+
# Run the FastAPI server
|
| 83 |
+
# The module path is constructed to work with the /app/env structure
|
| 84 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 85 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
README.md
CHANGED
|
@@ -1,10 +1,201 @@
|
|
| 1 |
---
|
| 2 |
-
title: TemporalBenchEnv
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: TemporalBenchEnv MCQ Server
|
| 3 |
+
emoji: 🥁
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# TemporalBenchEnv
|
| 15 |
+
|
| 16 |
+
OpenEnv environment for **multi-step multiple-choice** time-series reasoning. Each episode samples nine questions from pre-built JSON banks (per-dataset files or merged JSONL in `TSQuestion` shape). Rewards combine per-step correctness and an episode bonus (see `env/reward.py`).
|
| 17 |
+
|
| 18 |
+
## Question bank layout
|
| 19 |
+
|
| 20 |
+
Point the server at a directory containing `PSML_questions.json`, `freshretailnet_questions.json`, `MIMIC_questions.json`, and `causal_chambers_questions.json` (each file is a JSON array of `TSQuestion` records), or set **`TEMPORALBENCH_QUESTION_BANK_DIR`** to that path. If unset, the server uses `tests/fixtures/banks` when present (for local smoke runs).
|
| 21 |
+
|
| 22 |
+
Each record must include at least: `question_id`, `dataset`, `task_type` (`T1U` | `T3` | `T2_MCQ`), `prompt`, `options` (length ≥ 2), `answer`, plus optional `family`, `capability_tags`, `difficulty`, `metadata`.
|
| 23 |
+
|
| 24 |
+
## Quick Start
|
| 25 |
+
|
| 26 |
+
Use the typed client (`TemporalBenchEnvClient`; alias `TemporalbenchenvEnv`):
|
| 27 |
+
|
| 28 |
+
```python
|
| 29 |
+
from client import TemporalBenchAction, TemporalBenchEnvClient
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
env = TemporalBenchEnvClient.from_docker_image("temporalbenchenv-env:latest")
|
| 33 |
+
out = env.reset()
|
| 34 |
+
while not out.done:
|
| 35 |
+
q = out.observation
|
| 36 |
+
# Agent picks q.options[i] or equivalent label string
|
| 37 |
+
out = env.step(TemporalBenchAction(answer=q.options[0]))
|
| 38 |
+
finally:
|
| 39 |
+
env.close()
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
`TemporalBenchEnvClient.from_docker_image()` handles:
|
| 43 |
+
- Starting the Docker container
|
| 44 |
+
- Waiting for the server to be ready
|
| 45 |
+
- Connecting to the environment
|
| 46 |
+
- Container cleanup when you call `close()`
|
| 47 |
+
|
| 48 |
+
## Building the Docker Image
|
| 49 |
+
|
| 50 |
+
Before using the environment, you need to build the Docker image:
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
# From project root
|
| 54 |
+
docker build -t temporalbenchenv-env:latest -f server/Dockerfile .
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## Deploying to Hugging Face Spaces
|
| 58 |
+
|
| 59 |
+
You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
# From the environment directory (where openenv.yaml is located)
|
| 63 |
+
openenv push
|
| 64 |
+
|
| 65 |
+
# Or specify options
|
| 66 |
+
openenv push --namespace my-org --private
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
The `openenv push` command will:
|
| 70 |
+
1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
|
| 71 |
+
2. Prepare a custom build for Hugging Face Docker space (enables web interface)
|
| 72 |
+
3. Upload to Hugging Face (ensuring you're logged in)
|
| 73 |
+
|
| 74 |
+
### Prerequisites
|
| 75 |
+
|
| 76 |
+
- Authenticate with Hugging Face: The command will prompt for login if not already authenticated
|
| 77 |
+
|
| 78 |
+
### Options
|
| 79 |
+
|
| 80 |
+
- `--directory`, `-d`: Directory containing the OpenEnv environment (defaults to current directory)
|
| 81 |
+
- `--repo-id`, `-r`: Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
|
| 82 |
+
- `--base-image`, `-b`: Base Docker image to use (overrides Dockerfile FROM)
|
| 83 |
+
- `--private`: Deploy the space as private (default: public)
|
| 84 |
+
|
| 85 |
+
### Examples
|
| 86 |
+
|
| 87 |
+
```bash
|
| 88 |
+
# Push to your personal namespace (defaults to username/env-name from openenv.yaml)
|
| 89 |
+
openenv push
|
| 90 |
+
|
| 91 |
+
# Push to a specific repository
|
| 92 |
+
openenv push --repo-id my-org/my-env
|
| 93 |
+
|
| 94 |
+
# Push with a custom base image
|
| 95 |
+
openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
|
| 96 |
+
|
| 97 |
+
# Push as a private space
|
| 98 |
+
openenv push --private
|
| 99 |
+
|
| 100 |
+
# Combine options
|
| 101 |
+
openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
After deployment, your space will be available at:
|
| 105 |
+
`https://huggingface.co/spaces/<repo-id>`
|
| 106 |
+
|
| 107 |
+
The deployed space includes:
|
| 108 |
+
- **Web Interface** at `/web` - Interactive UI for exploring the environment
|
| 109 |
+
- **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
|
| 110 |
+
- **Health Check** at `/health` - Container health monitoring
|
| 111 |
+
- **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
|
| 112 |
+
|
| 113 |
+
## Environment Details
|
| 114 |
+
|
| 115 |
+
### Action (`TemporalBenchAction`)
|
| 116 |
+
- `answer` (str) — MCQ label (must match ground truth after optional normalization)
|
| 117 |
+
- `confidence`, `reasoning` — optional
|
| 118 |
+
|
| 119 |
+
### Observation (`TemporalBenchObservation`)
|
| 120 |
+
- `question`, `options`, `task_type`, `dataset`, `history`, `accuracy_so_far`
|
| 121 |
+
- `step_idx`, `steps_remaining`, `max_steps`, `done`, `reward`, `metadata`
|
| 122 |
+
|
| 123 |
+
### Reward
|
| 124 |
+
- Per step: `alpha * correctness` (correctness 0 or 1).
|
| 125 |
+
- On the final step, adds episode bonus: `lambda_ep * (total_correct / N) * coverage_multiplier` (1.0 if every dataset in the episode has at least one correct answer, else 0.8).
|
| 126 |
+
|
| 127 |
+
## Advanced Usage
|
| 128 |
+
|
| 129 |
+
### Connecting to an Existing Server
|
| 130 |
+
|
| 131 |
+
If you already have a TemporalBenchEnv server running, connect with:
|
| 132 |
+
|
| 133 |
+
```python
|
| 134 |
+
from client import TemporalBenchAction, TemporalBenchEnvClient
|
| 135 |
+
|
| 136 |
+
with TemporalBenchEnvClient(base_url="http://localhost:8000") as env:
|
| 137 |
+
r = env.reset()
|
| 138 |
+
r = env.step(TemporalBenchAction(answer=r.observation.options[0]))
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
Note: `close()` does not stop a remote server you attached to with `base_url=...`.
|
| 142 |
+
|
| 143 |
+
### Using the Context Manager
|
| 144 |
+
|
| 145 |
+
The client supports context manager usage for automatic connection management:
|
| 146 |
+
|
| 147 |
+
```python
|
| 148 |
+
from client import TemporalBenchAction, TemporalBenchEnvClient
|
| 149 |
+
|
| 150 |
+
with TemporalBenchEnvClient(base_url="http://localhost:8000") as env:
|
| 151 |
+
result = env.reset()
|
| 152 |
+
while not result.done:
|
| 153 |
+
ans = result.observation.options[0]
|
| 154 |
+
result = env.step(TemporalBenchAction(answer=ans))
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
The client uses WebSocket connections for:
|
| 158 |
+
- **Lower latency**: No HTTP connection overhead per request
|
| 159 |
+
- **Persistent session**: Server maintains your environment state
|
| 160 |
+
- **Efficient for episodes**: Better for many sequential steps
|
| 161 |
+
|
| 162 |
+
### Concurrent WebSocket Sessions
|
| 163 |
+
|
| 164 |
+
The server uses **factory mode** (`create_app(_env_factory, ...)`) so each WebSocket session gets a fresh `TemporalBenchEnvironment`. Tune `max_concurrent_envs` in `server/app.py` as needed.
|
| 165 |
+
|
| 166 |
+
## Development & Testing
|
| 167 |
+
|
| 168 |
+
### Direct environment testing
|
| 169 |
+
|
| 170 |
+
```bash
|
| 171 |
+
uv sync --extra dev
|
| 172 |
+
uv run pytest tests/
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
### Running Locally
|
| 176 |
+
|
| 177 |
+
Run the server locally for development:
|
| 178 |
+
|
| 179 |
+
```bash
|
| 180 |
+
uvicorn server.app:app --reload
|
| 181 |
+
```
|
| 182 |
+
|
| 183 |
+
## Project Structure
|
| 184 |
+
|
| 185 |
+
```
|
| 186 |
+
TemporalBenchEnv/
|
| 187 |
+
├── .dockerignore # Docker build exclusions
|
| 188 |
+
├── __init__.py # Module exports
|
| 189 |
+
├── README.md # This file
|
| 190 |
+
├── openenv.yaml # OpenEnv manifest
|
| 191 |
+
├── pyproject.toml # Project metadata and dependencies
|
| 192 |
+
├── uv.lock # Locked dependencies (generated)
|
| 193 |
+
├── client.py # TemporalBenchEnvClient (alias TemporalbenchenvEnv)
|
| 194 |
+
├── models.py # Action / observation / state re-exports
|
| 195 |
+
├── env/ # Environment, sampler, grading, rewards
|
| 196 |
+
├── data/ # TSQuestion schema + JSON/JSONL loaders
|
| 197 |
+
└── server/
|
| 198 |
+
├── __init__.py # Server module exports
|
| 199 |
+
├── app.py # FastAPI application (HTTP + WebSocket endpoints)
|
| 200 |
+
└── Dockerfile # Container image definition
|
| 201 |
+
```
|
__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""TemporalBenchEnv — MCQ time-series reasoning for OpenEnv."""
|
| 8 |
+
|
| 9 |
+
from client import TemporalBenchEnvClient, TemporalbenchenvEnv
|
| 10 |
+
from models import (
|
| 11 |
+
TemporalBenchAction,
|
| 12 |
+
TemporalBenchObservation,
|
| 13 |
+
TemporalBenchState,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
"TemporalBenchAction",
|
| 18 |
+
"TemporalBenchObservation",
|
| 19 |
+
"TemporalBenchState",
|
| 20 |
+
"TemporalBenchEnvClient",
|
| 21 |
+
"TemporalbenchenvEnv",
|
| 22 |
+
]
|
client.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Typed OpenEnv client for TemporalBenchEnv."""
|
| 8 |
+
|
| 9 |
+
from typing import Any, Dict
|
| 10 |
+
|
| 11 |
+
from openenv.core.client_types import StepResult
|
| 12 |
+
from openenv.core.env_client import EnvClient
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from env.models import (
|
| 16 |
+
TemporalBenchAction,
|
| 17 |
+
TemporalBenchObservation,
|
| 18 |
+
TemporalBenchState,
|
| 19 |
+
)
|
| 20 |
+
except ImportError:
|
| 21 |
+
from TemporalBenchEnv.env.models import (
|
| 22 |
+
TemporalBenchAction,
|
| 23 |
+
TemporalBenchObservation,
|
| 24 |
+
TemporalBenchState,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class TemporalBenchEnvClient(
    EnvClient[TemporalBenchAction, TemporalBenchObservation, TemporalBenchState]
):
    """WebSocket client for TemporalBench MCQ episodes.

    Implements the three serialization hooks of ``EnvClient``: turning an
    action into a request payload, and turning response payloads back into
    typed observation / state objects.
    """

    def _step_payload(self, action: TemporalBenchAction) -> Dict[str, Any]:
        """Serialize *action*; optional fields are omitted when they are ``None``."""
        body: Dict[str, Any] = {"answer": action.answer}
        for optional_key in ("confidence", "reasoning"):
            value = getattr(action, optional_key)
            if value is not None:
                body[optional_key] = value
        return body

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[TemporalBenchObservation]:
        """Build a ``StepResult`` from a reset/step response payload.

        The observation fields are read from ``payload["observation"]`` when
        that value is a dict; otherwise the payload itself is treated as the
        observation. Top-level ``done`` / ``reward`` take precedence over the
        values stored inside the observation.
        """
        nested = payload.get("observation")
        if isinstance(nested, dict):
            obs = nested
        elif isinstance(payload, dict):
            obs = payload
        else:
            obs = {}

        is_done = payload.get("done", obs.get("done", False))
        step_reward = payload.get("reward", obs.get("reward"))

        parsed = TemporalBenchObservation(
            step_idx=int(obs.get("step_idx", 0)),
            steps_remaining=int(obs.get("steps_remaining", 0)),
            max_steps=int(obs.get("max_steps", 9)),
            question=str(obs.get("question", "")),
            options=list(obs.get("options", [])),
            task_type=str(obs.get("task_type", "")),
            dataset=str(obs.get("dataset", "")),
            history=list(obs.get("history", [])),
            accuracy_so_far=float(obs.get("accuracy_so_far", 0.0)),
            done=is_done,
            reward=step_reward,
            metadata=obs.get("metadata", {}),
        )
        return StepResult(observation=parsed, reward=step_reward, done=is_done)

    def _parse_state(self, payload: Dict[str, Any]) -> TemporalBenchState:
        """Build a ``TemporalBenchState`` from a state response payload.

        Fields are read from ``payload["state"]`` when that value is a dict;
        otherwise from the payload itself. Missing fields fall back to the
        defaults spelled out below.
        """
        nested = payload.get("state")
        if isinstance(nested, dict):
            fields = nested
        elif isinstance(payload, dict):
            fields = payload
        else:
            fields = {}

        return TemporalBenchState(
            episode_id=fields.get("episode_id"),
            step_count=int(fields.get("step_count", 0)),
            total_correct=int(fields.get("total_correct", 0)),
            total_questions=int(fields.get("total_questions", 9)),
            current_accuracy=float(fields.get("current_accuracy", 0.0)),
            primary_domain=str(fields.get("primary_domain", "PSML")),
            per_task_type_accuracy=dict(fields.get("per_task_type_accuracy", {})),
            total_reward=float(fields.get("total_reward", 0.0)),
        )
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
TemporalbenchenvEnv = TemporalBenchEnvClient
|
data/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Question bank types and loaders for TemporalBenchEnv."""
|
| 2 |
+
|
| 3 |
+
from .question import TSQuestion
|
| 4 |
+
from .loaders import load_question_banks
|
| 5 |
+
|
| 6 |
+
__all__ = ["TSQuestion", "load_question_banks"]
|
data/banks/MIMIC_questions.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/banks/PSML_questions.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9fbe5a4b7ef698a2cf38cc1bad5e0a49e8ea2e8d5e694fed063992fb9b85202c
|
| 3 |
+
size 31377743
|
data/banks/causal_chambers_questions.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8fecebc4f5dc8db88b83825bd35d0a9b003f16c8df43521d616e928104481163
|
| 3 |
+
size 30021480
|
data/banks/freshretailnet_questions.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:475410d2787c20bf3d908e4f4b67febd78ff9da3f4e2bf993ef8a7c9af31431a
|
| 3 |
+
size 31221095
|
data/banks/manifest.json
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"source_file": "/Users/yashaswisharma/Desktop/Academia Career/usc-gradschool/Research/AgentX/TS-Agents-Melady/TS-benchmark/task_merged_dev_with_labels_tiers.jsonl",
|
| 3 |
+
"input_rows": 764,
|
| 4 |
+
"total_ts_questions": 2775,
|
| 5 |
+
"per_dataset": {
|
| 6 |
+
"PSML": 750,
|
| 7 |
+
"freshretailnet": 616,
|
| 8 |
+
"MIMIC": 709,
|
| 9 |
+
"causal_chambers": 700
|
| 10 |
+
},
|
| 11 |
+
"openenv_loader": "Set TEMPORALBENCH_QUESTION_BANK_DIR to this directory."
|
| 12 |
+
}
|
data/loaders.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Load question banks from JSON or JSONL files."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
from .question import TSQuestion
|
| 10 |
+
|
| 11 |
+
# Canonical domain keys used by EpisodeSampler (must match bank files or dataset field)
|
| 12 |
+
DEFAULT_DOMAIN_ORDER = ("PSML", "freshretailnet", "MIMIC", "causal_chambers")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _parse_records(raw: Any) -> list[dict[str, Any]]:
|
| 16 |
+
if isinstance(raw, list):
|
| 17 |
+
return [x for x in raw if isinstance(x, dict)]
|
| 18 |
+
if isinstance(raw, dict) and "questions" in raw:
|
| 19 |
+
q = raw["questions"]
|
| 20 |
+
if isinstance(q, list):
|
| 21 |
+
return [x for x in q if isinstance(x, dict)]
|
| 22 |
+
raise ValueError("JSON root must be a list of objects or {\"questions\": [...]}")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _record_to_question(obj: dict[str, Any]) -> TSQuestion:
    """Validate one raw record dict and return it as a ``TSQuestion``.

    Any error raised by ``TSQuestion.model_validate`` propagates unchanged.
    """
    question = TSQuestion.model_validate(obj)
    return question
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_json_file(path: Path) -> list[TSQuestion]:
    """Read one UTF-8 ``.json`` bank file and return its questions.

    The document must be a JSON array of records or an object with a
    ``"questions"`` array; every record dict is validated into a
    ``TSQuestion``.
    """
    document = json.loads(path.read_text(encoding="utf-8"))
    raw_records = _parse_records(document)
    questions: list[TSQuestion] = []
    for record in raw_records:
        questions.append(_record_to_question(record))
    return questions
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def load_jsonl_file(path: Path) -> list[TSQuestion]:
    """Load newline-delimited JSON; each line must be a full TSQuestion object.

    Blank (whitespace-only) lines are skipped. Line numbers in error messages
    are 1-based and count blank lines.

    Args:
        path: UTF-8 encoded ``.jsonl`` file.

    Returns:
        The validated questions, in file order.

    Raises:
        ValueError: if a non-blank line is not valid JSON or does not decode
            to a JSON object.
    """
    out: list[TSQuestion] = []
    # Iterate the file object instead of ``read_text().splitlines()``:
    # ``str.splitlines`` also breaks on U+2028, U+2029, U+0085 and friends,
    # which may appear unescaped inside a JSON string (e.g. when the bank was
    # written with ``ensure_ascii=False``) and would cut a valid record in two.
    # Text-mode iteration only splits on real line endings (\n, \r, \r\n).
    with path.open(encoding="utf-8") as handle:
        for line_no, line in enumerate(handle, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as e:
                raise ValueError(f"{path}:{line_no}: invalid JSON: {e}") from e
            if not isinstance(obj, dict):
                raise ValueError(f"{path}:{line_no}: expected object per line")
            out.append(_record_to_question(obj))
    return out
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def load_question_banks(
    bank_dir: Path | str | None,
    *,
    domain_order: tuple[str, ...] = DEFAULT_DOMAIN_ORDER,
    explicit_files: list[Path | str] | None = None,
) -> dict[str, list[TSQuestion]]:
    """
    Build the ``{dataset: [TSQuestion, ...]}`` pools used for episode sampling.

    Every name in ``domain_order`` is always a key of the result, possibly
    mapped to an empty list. Datasets outside ``domain_order`` get a key the
    first time one of their records is seen.

    Three modes, checked in this order:

    * ``explicit_files`` is non-empty: only those files are read (``bank_dir``
      is ignored) and each record is filed under its own ``dataset`` field.
    * ``bank_dir`` is ``None``: the empty pools are returned.
    * otherwise: for each domain the first existing file among
      ``<domain>_questions.json`` and ``<domain-lowercased>_questions.json`` is
      loaded into that domain's pool (regardless of the records' ``dataset``
      field); then every other ``*.json`` / ``*.jsonl`` in the directory,
      except ``manifest.json`` and ``build_manifest.json``, is loaded and its
      records are filed by their ``dataset`` field.

    Raises ``NotADirectoryError`` when ``bank_dir`` is given but is not a
    directory.
    """
    pools: dict[str, list[TSQuestion]] = {name: [] for name in domain_order}

    def _read(file_path: Path) -> list[TSQuestion]:
        # Dispatch on extension: .jsonl is line-delimited, anything else is one JSON document.
        if file_path.suffix.lower() == ".jsonl":
            return load_jsonl_file(file_path)
        return load_json_file(file_path)

    def _file_by_dataset(items: list[TSQuestion]) -> None:
        # Group records under the dataset each one declares.
        for item in items:
            pools.setdefault(item.dataset, []).append(item)

    if explicit_files:
        for entry in explicit_files:
            _file_by_dataset(_read(Path(entry)))
        return pools

    if bank_dir is None:
        return pools

    root = Path(bank_dir)
    if not root.is_dir():
        raise NotADirectoryError(f"question_bank_path must be a directory: {root}")

    # Per-dataset convention: PSML_questions.json etc. (exact case first, then lowercase).
    for domain in domain_order:
        for filename in (f"{domain}_questions.json", f"{domain.lower()}_questions.json"):
            candidate = root / filename
            if candidate.is_file():
                pools[domain].extend(load_json_file(candidate))
                break

    # Names that must not be re-read as merged files: manifests and the
    # per-dataset files handled above.
    reserved = {"manifest.json", "build_manifest.json"}
    for domain in domain_order:
        reserved.add(f"{domain}_questions.json")
        reserved.add(f"{domain.lower()}_questions.json")

    for extra in sorted(root.glob("*.json")) + sorted(root.glob("*.jsonl")):
        if extra.name in reserved:
            continue
        _file_by_dataset(_read(extra))

    return pools
|
data/question.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Canonical question record for TemporalBench MCQ episodes."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel, Field
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TSQuestion(BaseModel):
    """One MCQ item from a pre-built question bank.

    Records read from bank files are turned into instances of this model via
    ``TSQuestion.model_validate`` (see ``data/loaders.py``). No validator is
    declared here beyond the field types and ``min_length`` on ``options``.
    """

    # Required identifier for the question.
    question_id: str = Field(..., description="Unique identifier")
    # Required; the loaders use this value as the pool key when grouping
    # records from merged files.
    dataset: str = Field(
        ...,
        description="Source dataset, e.g. PSML, freshretailnet, MIMIC, causal_chambers",
    )
    # Required; the episode sampler filters questions on this value.
    task_type: str = Field(..., description="T1U | T3 | T2_MCQ")
    # Optional sub-category; absent records get None.
    family: str | None = Field(default=None, description="T3 family, e.g. S1:A")
    # Required question text.
    prompt: str = Field(..., description="Full question text with context")
    # Required; validation rejects fewer than two choices.
    options: list[str] = Field(..., min_length=2, description="Answer choices")
    # Required; NOTE(review): membership of ``answer`` in ``options`` is not
    # checked by this model — presumably enforced (or normalized) at grading time.
    answer: str = Field(..., description="Ground-truth label (matches one option after normalize)")
    # Optional extras; each instance gets its own fresh empty container.
    capability_tags: list[str] = Field(default_factory=list)
    difficulty: str | None = Field(default=None)
    metadata: dict[str, Any] = Field(default_factory=dict)
|
env/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TemporalBenchEnv core logic."""
|
| 2 |
+
|
| 3 |
+
from .config import EnvConfig
|
| 4 |
+
from .temporal_bench_env import TemporalBenchEnvironment
|
| 5 |
+
|
| 6 |
+
__all__ = ["EnvConfig", "TemporalBenchEnvironment"]
|
env/config.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TemporalBenchEnv configuration."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
DEFAULT_DOMAINS: tuple[str, ...] = ("PSML", "freshretailnet", "MIMIC", "causal_chambers")
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class EnvConfig:
    """Episode structure, rewards, and data paths.

    Plain mutable dataclass; no cross-field validation is performed here
    (e.g. nothing checks that ``primary_count + secondary_count`` equals
    ``num_questions``).
    """

    # Questions per episode.
    num_questions: int = 9
    # NOTE(review): presumably how many of the episode's questions come from
    # ``primary_domain`` vs. the other domains — confirm in EpisodeSampler.
    primary_count: int = 6
    secondary_count: int = 3
    primary_domain: str = "PSML"
    # Task types the sampler may use; intersected with the set unlocked by
    # ``curriculum_stage``.
    allowed_task_types: tuple[str, ...] = ("T1U", "T3", "T2_MCQ")
    # Reward weights; the README describes the per-step reward as
    # ``alpha * correctness`` and the episode bonus as scaled by ``lambda_ep``.
    alpha: float = 1.0
    lambda_ep: float = 0.5
    # NOTE(review): grading switches; their effect is not visible from this
    # module — confirm in env/grading.py.
    partial_credit_t1u: bool = False
    case_insensitive_match: bool = False
    # Curriculum level: <= 1 unlocks T1U only, 2 adds T3, anything higher
    # also adds T2_MCQ (see ``_curriculum_task_types``).
    curriculum_stage: int = 3
    # NOTE(review): presumably the directory handed to the question-bank
    # loader; None means no path configured — confirm against the environment.
    question_bank_path: str | None = None
    # NOTE(review): presumably seeds the episode sampler's RNG — confirm.
    seed: int | None = None
    # NOTE(review): usage not visible from this module.
    forecasting_enabled: bool = False
    # Defaults to the module-level DEFAULT_DOMAINS tuple.
    all_domains: tuple[str, ...] = field(default_factory=lambda: DEFAULT_DOMAINS)
|
env/episode_sampler.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Sample 9-question episodes from per-dataset pools."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from typing import TYPE_CHECKING
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from .config import DEFAULT_DOMAINS, EnvConfig
|
| 11 |
+
|
| 12 |
+
if TYPE_CHECKING:
|
| 13 |
+
from data.question import TSQuestion
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _curriculum_task_types(config: EnvConfig) -> set[str]:
|
| 17 |
+
stage = int(config.curriculum_stage)
|
| 18 |
+
if stage <= 1:
|
| 19 |
+
allowed = {"T1U"}
|
| 20 |
+
elif stage == 2:
|
| 21 |
+
allowed = {"T1U", "T3"}
|
| 22 |
+
else:
|
| 23 |
+
allowed = {"T1U", "T3", "T2_MCQ"}
|
| 24 |
+
return allowed.intersection(set(config.allowed_task_types))
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _filter_pool(questions: list[TSQuestion], config: EnvConfig) -> list[TSQuestion]:
    """Keep only the questions whose ``task_type`` is enabled by ``config``.

    The enabled set comes from ``_curriculum_task_types``; input order is
    preserved and a new list is returned.
    """
    allowed_types = _curriculum_task_types(config)
    return [q for q in questions if q.task_type in allowed_types]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class EpisodeSampler:
    """Builds stratified multi-domain MCQ episodes.

    At construction the per-domain pools are filtered down to the task types
    enabled by the config; domains left with no questions are dropped.
    ``sample_episode`` then draws ``primary_count`` questions from the primary
    domain plus one question from each non-primary domain.
    """

    def __init__(
        self,
        pools: dict[str, list[TSQuestion]],
        config: EnvConfig,
        rng: np.random.Generator,
    ):
        """Store config/RNG and pre-filter every domain pool.

        Args:
            pools: Mapping of dataset name to its questions.
            config: Supplies curriculum filtering and episode layout.
            rng: NumPy generator used for every random draw (shared, not copied).
        """
        self._config = config
        self._rng = rng
        self._pools: dict[str, list[TSQuestion]] = {}
        for domain, qs in pools.items():
            filtered = _filter_pool(list(qs), config)
            if filtered:
                self._pools[domain] = filtered

    def sample_episode(self) -> list[TSQuestion]:
        """Return shuffled list of ``num_questions`` questions (default 9).

        Raises:
            ValueError: If the primary domain is not in ``all_domains``, there
                are not exactly three non-primary domains, a domain has no
                questions after filtering, the primary pool is too small, or
                the picks do not add up to ``num_questions``.
        """
        cfg = self._config
        n = cfg.num_questions
        primary_n = cfg.primary_count
        domains = list(cfg.all_domains)
        primary = cfg.primary_domain
        if primary not in domains:
            raise ValueError(f"primary_domain {primary!r} not in all_domains {domains}")

        others = [d for d in domains if d != primary]
        if len(others) != 3:
            raise ValueError("Expected exactly 3 non-primary domains")

        for d in domains:
            # A missing key and an empty pool are both treated as "no questions".
            if not self._pools.get(d):
                raise ValueError(f"No questions available for dataset {d!r} after filtering")

        # Work on a copy so the stored pool is never mutated by sampling.
        primary_pool = list(self._pools[primary])
        chosen: list[TSQuestion] = []
        chosen.extend(self._sample_primary_stratified(primary_pool, primary_n))

        # One uniformly drawn question per non-primary domain.
        for d in others:
            idx = int(self._rng.integers(0, len(self._pools[d])))
            chosen.append(self._pools[d][idx])

        if len(chosen) != n:
            raise ValueError(f"Expected {n} questions, got {len(chosen)}")

        order = self._rng.permutation(n)
        return [chosen[i] for i in order]

    def _sample_primary_stratified(self, pool: list[TSQuestion], k: int) -> list[TSQuestion]:
        """Pick ``k`` distinct questions, round-robin over (task_type, family) buckets.

        T3 buckets are visited first, in sorted family order; the remaining
        buckets follow in a random order. Each pass takes one random question
        from every non-empty bucket until ``k`` questions are collected.

        Raises:
            ValueError: If ``pool`` holds fewer than ``k`` questions.
        """
        rng = self._rng
        if len(pool) < k:
            raise ValueError(f"Primary pool too small: need {k}, have {len(pool)}")

        # Partition by (task_type, family or "").
        buckets: dict[tuple[str, str], list[TSQuestion]] = defaultdict(list)
        for q in pool:
            buckets[(q.task_type, q.family or "")].append(q)

        # Round-robin order across T3 families first, then other keys.
        t3_families = sorted({fam for (tt, fam) in buckets if tt == "T3"})
        rr_keys: list[tuple[str, str]] = [("T3", f) for f in t3_families]
        rr_key_set = set(rr_keys)  # O(1) membership instead of scanning the list
        other_keys = [key for key in buckets if key not in rr_key_set]
        rng.shuffle(other_keys)
        key_order = rr_keys + other_keys

        picks: list[TSQuestion] = []
        while len(picks) < k and key_order:
            progressed = False
            # Iterate over a snapshot because exhausted keys are removed below.
            for key in list(key_order):
                if len(picks) >= k:
                    break
                bucket = buckets.get(key, [])
                if not bucket:
                    key_order.remove(key)
                    continue
                idx = int(rng.integers(0, len(bucket)))
                picks.append(bucket.pop(idx))
                progressed = True
            if not progressed:
                break

        # Fallback fill from whatever is left in the buckets. Because
        # len(pool) >= k, the round-robin above already reaches k picks; the
        # shuffle is kept so the seeded random stream stays unchanged.
        remainder = [q for lst in buckets.values() for q in lst]
        rng.shuffle(remainder)
        for q in remainder:
            if len(picks) >= k:
                break
            picks.append(q)

        # Last-resort top-up by question id (defensive; see note above).
        if len(picks) < k:
            chosen_ids = {p.question_id for p in picks}
            pool_rest = [q for q in pool if q.question_id not in chosen_ids]
            need = k - len(picks)
            if len(pool_rest) < need:
                raise ValueError("Primary pool could not supply enough unique questions")
            extra_idx = rng.choice(len(pool_rest), size=need, replace=False)
            for i in extra_idx:
                picks.append(pool_rest[int(i)])

        return picks[:k]
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def default_domains_for_config(config: EnvConfig) -> tuple[str, ...]:
    """Domains tuple from config (falls back to module default).

    Returns ``config.all_domains`` as a tuple when it is non-empty, otherwise
    ``DEFAULT_DOMAINS``.
    """
    return tuple(config.all_domains) if config.all_domains else DEFAULT_DOMAINS
|
env/grading.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MCQ answer grading."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import TYPE_CHECKING
|
| 6 |
+
|
| 7 |
+
if TYPE_CHECKING:
|
| 8 |
+
from data.question import TSQuestion
|
| 9 |
+
|
| 10 |
+
from .config import EnvConfig
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _normalize(s: str, *, lower: bool) -> str:
|
| 14 |
+
s = str(s).strip()
|
| 15 |
+
return s.lower() if lower else s
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def grade_answer(
    submitted: str,
    question: TSQuestion,
    config: EnvConfig,
) -> tuple[bool, float]:
    """
    Returns (is_fully_correct, partial_score in [0,1]).

    Binary match between the normalized submission and the normalized
    ``question.answer``. Normalization strips whitespace and, when
    ``config.case_insensitive_match`` is set, lower-cases both sides.
    ``partial_credit_t1u`` is reserved for future multi-part T1U grading.
    """
    lower = config.case_insensitive_match
    truth = _normalize(question.answer, lower=lower)
    pred = _normalize(submitted, lower=lower)

    # The former scan over ``question.options`` could only succeed when
    # ``pred == truth``, which this comparison already covers, so it was
    # unreachable and has been removed.
    if pred == truth:
        return True, 1.0
    return False, 0.0
|
env/models.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenEnv Pydantic models for TemporalBenchEnv."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from openenv.core.env_server.types import Action as _ActionBase
|
| 11 |
+
from openenv.core.env_server.types import Observation as _ObservationBase
|
| 12 |
+
from openenv.core.env_server.types import State as _StateBase
|
| 13 |
+
except ImportError:
|
| 14 |
+
_ActionBase = BaseModel
|
| 15 |
+
_ObservationBase = BaseModel
|
| 16 |
+
_StateBase = BaseModel
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TemporalBenchAction(_ActionBase):
    """Agent submits an MCQ answer (optional confidence / reasoning)."""

    # Only when the OpenEnv base could not be imported (plain pydantic
    # fallback): forbid unknown fields and declare ``metadata`` ourselves.
    # NOTE(review): presumably the OpenEnv ``Action`` base already provides
    # these - confirm against openenv.core.env_server.types.
    if _ActionBase is BaseModel:
        model_config = ConfigDict(extra="forbid")
        metadata: dict[str, Any] = Field(default_factory=dict)

    # Required answer text; optional self-reported confidence in [0, 1].
    answer: str = Field(..., description="MCQ answer label matching an option")
    confidence: float | None = Field(default=None, ge=0.0, le=1.0)
    reasoning: str | None = Field(default=None, description="Optional chain-of-thought")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class TemporalBenchObservation(_ObservationBase):
    """Current question and progress."""

    # Only for the plain-pydantic fallback: declare the fields the OpenEnv
    # ``Observation`` base presumably supplies (TODO confirm against
    # openenv.core.env_server.types).
    if _ObservationBase is BaseModel:
        model_config = ConfigDict(extra="forbid")
        done: bool = Field(default=False)
        reward: float | None = Field(default=None)
        metadata: dict[str, Any] = Field(default_factory=dict)

    # Progress counters within the episode.
    step_idx: int = Field(..., ge=0)
    steps_remaining: int = Field(..., ge=0)
    max_steps: int = Field(default=9, ge=1)
    # The question currently presented to the agent.
    question: str = Field(..., description="Current MCQ prompt")
    options: list[str] = Field(..., description="Answer choices")
    task_type: str = Field(..., description="T1U | T3 | T2_MCQ")
    dataset: str = Field(..., description="Source dataset")
    # One dict per answered question, plus running accuracy.
    history: list[dict[str, Any]] = Field(default_factory=list)
    accuracy_so_far: float = Field(default=0.0, ge=0.0, le=1.0)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class TemporalBenchState(_StateBase):
    """Serializable environment state."""

    # Only for the plain-pydantic fallback: allow extra fields and declare
    # the fields the OpenEnv ``State`` base presumably supplies (TODO confirm
    # against openenv.core.env_server.types).
    if _StateBase is BaseModel:
        model_config = ConfigDict(extra="allow")
        episode_id: str | None = Field(default=None)
        step_count: int = Field(default=0, ge=0)

    # Episode-level scoring summary.
    total_correct: int = Field(default=0, ge=0)
    total_questions: int = Field(default=9, ge=0)
    current_accuracy: float = Field(default=0.0, ge=0.0, le=1.0)
    primary_domain: str = Field(default="PSML")
    per_task_type_accuracy: dict[str, float] = Field(default_factory=dict)
    total_reward: float = Field(default=0.0)
|
env/reward.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Reward helpers for TemporalBenchEnv."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def compute_mcq_reward(correctness_score: float, alpha: float = 1.0) -> float:
    """Per-step reward: alpha * correctness (0..1).

    Both arguments are coerced to ``float`` before multiplying.
    """
    return float(alpha) * float(correctness_score)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def compute_forecasting_reward(
    predicted: np.ndarray,
    ground_truth: np.ndarray,
    metric: str = "mse",
    **kwargs: Any,
) -> float:
    """Placeholder for a forecasting reward; not implemented yet.

    Raises:
        NotImplementedError: Always.
    """
    # Arguments are intentionally discarded until the reward is implemented.
    del predicted, ground_truth, metric, kwargs
    raise NotImplementedError("Forecasting reward not yet implemented")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def compute_episode_bonus(
    total_correct: int,
    num_questions: int,
    domain_coverage: dict[str, int],
    *,
    all_domains: tuple[str, ...],
    lambda_ep: float = 0.5,
) -> float:
    """
    Episode bonus: lambda_ep * episode_accuracy * domain_coverage_bonus.

    ``domain_coverage`` maps dataset name -> count of *correct* answers in the
    episode for that dataset. Bonus uses 1.0 if every domain in ``all_domains``
    has at least one correct answer, else 0.8. Domains missing from
    ``domain_coverage`` count as zero, and accuracy is 0.0 when
    ``num_questions`` is not positive.
    """
    accuracy = total_correct / num_questions if num_questions > 0 else 0.0
    counts = [int(domain_coverage.get(d, 0)) for d in all_domains]
    all_covered = all(c > 0 for c in counts)
    coverage_bonus = 1.0 if all_covered else 0.8
    return float(lambda_ep) * float(accuracy) * coverage_bonus
|
env/temporal_bench_env.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Core OpenEnv environment for TemporalBench MCQ episodes."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import uuid
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
from dataclasses import replace
|
| 8 |
+
from typing import Any, Optional
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
from data.loaders import load_question_banks
|
| 13 |
+
from data.question import TSQuestion
|
| 14 |
+
|
| 15 |
+
from .config import EnvConfig
|
| 16 |
+
from .episode_sampler import EpisodeSampler
|
| 17 |
+
from .grading import grade_answer
|
| 18 |
+
from .models import TemporalBenchAction, TemporalBenchObservation, TemporalBenchState
|
| 19 |
+
from .reward import compute_episode_bonus, compute_mcq_reward
|
| 20 |
+
|
| 21 |
+
try:
    from openenv.core.env_server.interfaces import Environment
except ImportError:
    # OpenEnv is not installed: provide a minimal abstract stand-in exposing
    # the same three members (reset, step, state) used by the class below.
    from abc import ABC, abstractmethod
    from typing import Generic, TypeVar

    ActT = TypeVar("ActT")
    ObsT = TypeVar("ObsT")
    StateT = TypeVar("StateT")

    class Environment(ABC, Generic[ActT, ObsT, StateT]):
        """Fallback environment interface, generic over action/observation/state."""

        @abstractmethod
        def reset(self, seed=None, episode_id=None, **kwargs): ...

        @abstractmethod
        def step(self, action, timeout_s=None, **kwargs): ...

        @property
        @abstractmethod
        def state(self): ...
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class TemporalBenchEnvironment(
    Environment[TemporalBenchAction, TemporalBenchObservation, TemporalBenchState]
):
    """Multi-step MCQ environment over a pre-built TemporalBench question bank.

    ``reset`` samples a fresh episode of ``config.num_questions`` questions;
    each ``step`` grades one answer and advances to the next question. The
    final step additionally adds an episode bonus to its reward.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self, config: Optional[EnvConfig] = None, **kwargs: Any):
        """Load the question banks and initialise empty episode state.

        Args:
            config: Environment configuration; defaults to ``EnvConfig()``.
            **kwargs: Forwarded unchanged to the base ``Environment``.
        """
        super().__init__(**kwargs)
        self._config = config or EnvConfig()
        seed = self._config.seed
        self._rng = np.random.default_rng(seed)

        self._banks = load_question_banks(self._config.question_bank_path)
        self._sampler = EpisodeSampler(self._banks, self._config, self._rng)

        # Per-episode state; all of it is re-initialised by ``reset``.
        self._episode_id: Optional[str] = None
        self._questions: list[TSQuestion] = []
        self._answered: int = 0
        self._history: list[dict[str, Any]] = []
        self._done: bool = False
        self._total_correct: int = 0
        self._total_reward: float = 0.0
        self._domain_correct: dict[str, int] = defaultdict(int)
        self._task_correct: dict[str, int] = defaultdict(int)
        self._task_total: dict[str, int] = defaultdict(int)
        self._last_metadata: dict[str, Any] = {}

    def _accuracy_so_far(self) -> float:
        """Fraction of answered questions that were correct (0.0 before any answer)."""
        if self._answered == 0:
            return 0.0
        return self._total_correct / self._answered

    def _per_task_accuracy(self) -> dict[str, float]:
        """Accuracy per task type, over the task types answered so far."""
        out: dict[str, float] = {}
        for k, tot in self._task_total.items():
            out[k] = (self._task_correct[k] / tot) if tot else 0.0
        return out

    def _build_observation(
        self,
        *,
        reward: float | None,
        done: bool,
    ) -> TemporalBenchObservation:
        """Build the observation for the current position in the episode.

        When the episode is over (``done`` or every question answered) a
        terminal observation with an empty question is returned; otherwise
        the next unanswered question is presented.
        """
        n = self._config.num_questions
        if done or self._answered >= n:
            return TemporalBenchObservation(
                step_idx=self._answered,
                steps_remaining=0,
                max_steps=n,
                question="",
                options=[],
                task_type="",
                dataset="",
                history=list(self._history),
                accuracy_so_far=self._accuracy_so_far(),
                done=True,
                reward=reward,
                metadata=dict(self._last_metadata),
            )

        q = self._questions[self._answered]
        steps_remaining = n - self._answered
        return TemporalBenchObservation(
            step_idx=self._answered,
            steps_remaining=steps_remaining,
            max_steps=n,
            question=q.prompt,
            options=list(q.options),
            task_type=q.task_type,
            dataset=q.dataset,
            history=list(self._history),
            accuracy_so_far=self._accuracy_so_far(),
            done=False,
            reward=reward,
            metadata=dict(self._last_metadata),
        )

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> TemporalBenchObservation:
        """Start a new episode and return the observation for its first question.

        Args:
            seed: When given, re-seeds the environment's random generator.
            episode_id: Identifier for the episode; a UUID4 string is
                generated when omitted.
            **kwargs: ``curriculum_stage`` (optional) overrides the configured
                stage for this episode's sampling only; other keys are ignored.
        """
        curriculum_kw = kwargs.pop("curriculum_stage", None)

        if seed is not None:
            self._rng = np.random.default_rng(seed)

        # The override is applied to a copy; ``self._config`` is left untouched.
        cfg = self._config
        if curriculum_kw is not None:
            cfg = replace(self._config, curriculum_stage=int(curriculum_kw))

        self._sampler = EpisodeSampler(self._banks, cfg, self._rng)

        self._episode_id = episode_id or str(uuid.uuid4())
        self._questions = self._sampler.sample_episode()
        self._answered = 0
        self._history = []
        self._done = False
        self._total_correct = 0
        self._total_reward = 0.0
        self._domain_correct = defaultdict(int)
        self._task_correct = defaultdict(int)
        self._task_total = defaultdict(int)
        self._last_metadata = {}

        return self._build_observation(reward=0.0, done=False)

    def step(
        self,
        action: TemporalBenchAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> TemporalBenchObservation:
        """Grade ``action.answer`` against the current question and advance.

        A blank answer is rejected without advancing (reward 0.0, an
        ``error`` entry in the observation metadata). Stepping after the
        episode has finished returns a terminal observation with reward 0.0.
        On the last question the episode bonus is added to the step reward.
        """
        del timeout_s, kwargs
        if self._done:
            self._last_metadata = {"info": "Episode already done."}
            return self._build_observation(reward=0.0, done=True)

        self._last_metadata = {}

        n = self._config.num_questions
        if self._answered >= n:
            self._done = True
            self._last_metadata = {"info": "Episode already complete."}
            return self._build_observation(reward=0.0, done=True)

        q = self._questions[self._answered]
        if not str(action.answer).strip():
            self._last_metadata = {"error": "answer must be a non-empty string."}
            return self._build_observation(reward=0.0, done=False)

        fully_correct, score = grade_answer(action.answer, q, self._config)
        r_step = compute_mcq_reward(score, alpha=self._config.alpha)

        # History records the per-step reward only (without any episode bonus).
        self._history.append(
            {
                "question_id": q.question_id,
                "dataset": q.dataset,
                "task_type": q.task_type,
                "submitted": action.answer,
                "correct": fully_correct,
                "reward": r_step,
            }
        )

        self._task_total[q.task_type] += 1
        if fully_correct:
            self._total_correct += 1
            self._domain_correct[q.dataset] += 1
            self._task_correct[q.task_type] += 1

        self._answered += 1
        total_reward_this_step = r_step

        if self._answered >= n:
            # Final question answered: add the episode bonus and finish.
            bonus = compute_episode_bonus(
                self._total_correct,
                n,
                dict(self._domain_correct),
                all_domains=tuple(self._config.all_domains),
                lambda_ep=self._config.lambda_ep,
            )
            total_reward_this_step = r_step + bonus
            self._done = True
            self._last_metadata = {
                "episode_bonus": bonus,
                "domain_correct_counts": dict(self._domain_correct),
            }

        self._total_reward += total_reward_this_step

        return self._build_observation(
            reward=total_reward_this_step,
            done=self._done,
        )

    @property
    def state(self) -> TemporalBenchState:
        """Snapshot of the episode's progress and scoring so far."""
        return TemporalBenchState(
            episode_id=self._episode_id,
            step_count=self._answered,
            total_correct=self._total_correct,
            total_questions=self._config.num_questions,
            current_accuracy=self._accuracy_so_far(),
            primary_domain=self._config.primary_domain,
            per_task_type_accuracy=self._per_task_accuracy(),
            total_reward=self._total_reward,
        )
|
models.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""OpenEnv models at package root for schema validation and server imports."""
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from env.models import (
|
| 11 |
+
TemporalBenchAction,
|
| 12 |
+
TemporalBenchObservation,
|
| 13 |
+
TemporalBenchState,
|
| 14 |
+
)
|
| 15 |
+
except ImportError:
|
| 16 |
+
from .env.models import (
|
| 17 |
+
TemporalBenchAction,
|
| 18 |
+
TemporalBenchObservation,
|
| 19 |
+
TemporalBenchState,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
__all__ = [
|
| 23 |
+
"TemporalBenchAction",
|
| 24 |
+
"TemporalBenchObservation",
|
| 25 |
+
"TemporalBenchState",
|
| 26 |
+
]
|
openenv.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: TemporalBenchEnv
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
| 7 |
+
|
openenv_TemporalBenchEnv.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: openenv-TemporalBenchEnv
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: TemporalBench MCQ environment for OpenEnv
|
| 5 |
+
Requires-Python: >=3.10
|
| 6 |
+
Requires-Dist: openenv-core[core]>=0.2.0
|
| 7 |
+
Requires-Dist: numpy>=1.24.0
|
| 8 |
+
Requires-Dist: pydantic>=2.0.0
|
| 9 |
+
Provides-Extra: dev
|
| 10 |
+
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
| 11 |
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
openenv_TemporalBenchEnv.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
client.py
|
| 3 |
+
models.py
|
| 4 |
+
pyproject.toml
|
| 5 |
+
data/__init__.py
|
| 6 |
+
data/loaders.py
|
| 7 |
+
data/question.py
|
| 8 |
+
env/__init__.py
|
| 9 |
+
env/config.py
|
| 10 |
+
env/episode_sampler.py
|
| 11 |
+
env/grading.py
|
| 12 |
+
env/models.py
|
| 13 |
+
env/reward.py
|
| 14 |
+
env/temporal_bench_env.py
|
| 15 |
+
openenv_TemporalBenchEnv.egg-info/PKG-INFO
|
| 16 |
+
openenv_TemporalBenchEnv.egg-info/SOURCES.txt
|
| 17 |
+
openenv_TemporalBenchEnv.egg-info/dependency_links.txt
|
| 18 |
+
openenv_TemporalBenchEnv.egg-info/entry_points.txt
|
| 19 |
+
openenv_TemporalBenchEnv.egg-info/requires.txt
|
| 20 |
+
openenv_TemporalBenchEnv.egg-info/top_level.txt
|
| 21 |
+
server/__init__.py
|
| 22 |
+
server/app.py
|
| 23 |
+
tests/test_env.py
|
| 24 |
+
tests/test_loaders.py
|
| 25 |
+
tests/test_reward.py
|
| 26 |
+
tests/test_sampler.py
|
openenv_TemporalBenchEnv.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
openenv_TemporalBenchEnv.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
server = server.app:main
|
openenv_TemporalBenchEnv.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.0
|
| 2 |
+
numpy>=1.24.0
|
| 3 |
+
pydantic>=2.0.0
|
| 4 |
+
|
| 5 |
+
[dev]
|
| 6 |
+
pytest>=8.0.0
|
| 7 |
+
pytest-cov>=4.0.0
|
openenv_TemporalBenchEnv.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
client
|
| 2 |
+
data
|
| 3 |
+
env
|
| 4 |
+
models
|
| 5 |
+
server
|
pyproject.toml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=61.0", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-TemporalBenchEnv"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "TemporalBench MCQ environment for OpenEnv"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
"openenv-core[core]>=0.2.0",
|
| 18 |
+
"numpy>=1.24.0",
|
| 19 |
+
"pydantic>=2.0.0",
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
[project.optional-dependencies]
|
| 23 |
+
dev = [
|
| 24 |
+
"pytest>=8.0.0",
|
| 25 |
+
"pytest-cov>=4.0.0",
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
[project.scripts]
|
| 29 |
+
server = "server.app:main"
|
| 30 |
+
|
| 31 |
+
[tool.setuptools]
|
| 32 |
+
include-package-data = true
|
| 33 |
+
py-modules = ["client", "models"]
|
| 34 |
+
|
| 35 |
+
[tool.setuptools.packages.find]
|
| 36 |
+
where = ["."]
|
| 37 |
+
include = ["env*", "data*", "server*"]
|
| 38 |
+
|
| 39 |
+
[tool.pytest.ini_options]
|
| 40 |
+
pythonpath = ["."]
|
| 41 |
+
testpaths = ["tests"]
|
server/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""TemporalBenchEnv environment server components."""
|
| 8 |
+
|
| 9 |
+
from env.temporal_bench_env import TemporalBenchEnvironment
|
| 10 |
+
|
| 11 |
+
__all__ = ["TemporalBenchEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""FastAPI application for TemporalBenchEnv."""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from env.config import EnvConfig
|
| 14 |
+
from env.temporal_bench_env import TemporalBenchEnvironment
|
| 15 |
+
from models import TemporalBenchAction, TemporalBenchObservation
|
| 16 |
+
except ImportError:
|
| 17 |
+
from ..env.config import EnvConfig
|
| 18 |
+
from ..env.temporal_bench_env import TemporalBenchEnvironment
|
| 19 |
+
from ..models import TemporalBenchAction, TemporalBenchObservation
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
from openenv.core.env_server import create_app
|
| 23 |
+
except ImportError:
|
| 24 |
+
create_app = None # type: ignore
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _env_factory():
    """Create a fresh environment instance per WebSocket session.

    The question-bank directory comes from the
    ``TEMPORALBENCH_QUESTION_BANK_DIR`` environment variable. When it is
    unset or empty, ``<repo root>/tests/fixtures/banks`` is used if that
    directory exists; otherwise ``EnvConfig()`` defaults apply.
    """
    bank_dir = os.environ.get("TEMPORALBENCH_QUESTION_BANK_DIR")
    if not bank_dir:
        # NOTE(review): this prefers the test-fixture banks over the config
        # default whenever the fixtures directory is present - confirm that
        # is intended for deployed servers.
        default = Path(__file__).resolve().parents[1] / "tests" / "fixtures" / "banks"
        if default.is_dir():
            bank_dir = str(default)
    cfg = EnvConfig(question_bank_path=bank_dir) if bank_dir else EnvConfig()
    return TemporalBenchEnvironment(config=cfg)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Build the ASGI app: the full OpenEnv server when ``create_app`` imported,
# otherwise a bare FastAPI app that only answers ``GET /health``.
if create_app is not None:
    app = create_app(
        _env_factory,
        TemporalBenchAction,
        TemporalBenchObservation,
        env_name="temporal-bench-env",
        max_concurrent_envs=64,
    )
else:
    from fastapi import FastAPI

    app = FastAPI(title="temporal-bench-env")
    app.get("/health")(lambda: {"status": "ok"})
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def main(host: str | None = None, port: int | None = None) -> None:
    """
    Run the server with uvicorn (entry point for `uv run server`).

    OpenEnv's multi-mode validator does a naive substring check for
    ``main()`` in this file, so the ``if __name__ == "__main__"`` block has
    to call ``main()`` without arguments. Command-line flags are therefore
    parsed in here with ``parse_known_args``.

    Args:
        host: Interface to bind. When ``None``, taken from ``--host``
            (default ``0.0.0.0``).
        port: Port to bind. When ``None``, taken from ``--port``
            (default ``8000``).
    """
    import argparse

    import uvicorn

    needs_cli = host is None or port is None
    if needs_cli:
        cli = argparse.ArgumentParser()
        cli.add_argument("--host", type=str, default="0.0.0.0")
        cli.add_argument("--port", type=int, default=8000)
        # Unknown flags are tolerated rather than rejected.
        parsed, _ignored = cli.parse_known_args()
        # Explicit arguments win; the CLI only fills in what is missing.
        host = parsed.host if host is None else host
        port = parsed.port if port is None else port

    uvicorn.run(app, host=host, port=port)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Call ``main()`` with no arguments: as the ``main`` docstring explains,
# OpenEnv's validator looks for the literal ``main()`` substring in this file,
# and CLI flags are parsed inside ``main`` itself.
if __name__ == "__main__":
    main()
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv[core]>=0.2.0
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
FIXTURE_BANK_DIR = Path(__file__).resolve().parent / "fixtures" / "banks"
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@pytest.fixture
def fixture_bank_dir() -> Path:
    """Provide the fixture question-bank directory, checking that it exists."""
    bank_dir = FIXTURE_BANK_DIR
    assert bank_dir.is_dir()
    return bank_dir
|
tests/fixtures/banks/MIMIC_questions.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{"question_id": "mim-1", "dataset": "MIMIC", "task_type": "T1U", "family": null, "prompt": "M1", "options": ["0", "1"], "answer": "0", "capability_tags": [], "difficulty": null, "metadata": {}},
|
| 3 |
+
{"question_id": "mim-2", "dataset": "MIMIC", "task_type": "T2_MCQ", "family": null, "prompt": "M2", "options": ["0", "1"], "answer": "1", "capability_tags": [], "difficulty": null, "metadata": {}}
|
| 4 |
+
]
|
tests/fixtures/banks/PSML_questions.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{"question_id": "psml-t3-s1a-1", "dataset": "PSML", "task_type": "T3", "family": "S1:A", "prompt": "Q1", "options": ["a", "b"], "answer": "a", "capability_tags": ["C1"], "difficulty": "easy", "metadata": {}},
|
| 3 |
+
{"question_id": "psml-t3-s2b-1", "dataset": "PSML", "task_type": "T3", "family": "S2:B", "prompt": "Q2", "options": ["a", "b"], "answer": "b", "capability_tags": ["C2"], "difficulty": "easy", "metadata": {}},
|
| 4 |
+
{"question_id": "psml-t3-s3c-1", "dataset": "PSML", "task_type": "T3", "family": "S3:C", "prompt": "Q3", "options": ["x", "y"], "answer": "x", "capability_tags": [], "difficulty": null, "metadata": {}},
|
| 5 |
+
{"question_id": "psml-t1u-1", "dataset": "PSML", "task_type": "T1U", "family": null, "prompt": "Q4 trend?", "options": ["upward", "downward"], "answer": "upward", "capability_tags": [], "difficulty": null, "metadata": {}},
|
| 6 |
+
{"question_id": "psml-t2-1", "dataset": "PSML", "task_type": "T2_MCQ", "family": null, "prompt": "Q5", "options": ["m1", "m2"], "answer": "m1", "capability_tags": [], "difficulty": null, "metadata": {}},
|
| 7 |
+
{"question_id": "psml-t3-s4d-1", "dataset": "PSML", "task_type": "T3", "family": "S4:D", "prompt": "Q6", "options": ["p", "q"], "answer": "p", "capability_tags": [], "difficulty": null, "metadata": {}},
|
| 8 |
+
{"question_id": "psml-t1u-2", "dataset": "PSML", "task_type": "T1U", "family": null, "prompt": "Q7", "options": ["yes", "no"], "answer": "yes", "capability_tags": [], "difficulty": null, "metadata": {}},
|
| 9 |
+
{"question_id": "psml-t1u-3", "dataset": "PSML", "task_type": "T1U", "family": null, "prompt": "Q7b", "options": ["yes", "no"], "answer": "no", "capability_tags": [], "difficulty": null, "metadata": {}},
|
| 10 |
+
{"question_id": "psml-t1u-4", "dataset": "PSML", "task_type": "T1U", "family": null, "prompt": "Q7c", "options": ["u", "d"], "answer": "u", "capability_tags": [], "difficulty": null, "metadata": {}},
|
| 11 |
+
{"question_id": "psml-t1u-5", "dataset": "PSML", "task_type": "T1U", "family": null, "prompt": "Q7d", "options": ["u", "d"], "answer": "d", "capability_tags": [], "difficulty": null, "metadata": {}},
|
| 12 |
+
{"question_id": "psml-t1u-6", "dataset": "PSML", "task_type": "T1U", "family": null, "prompt": "Q7e", "options": ["e", "f"], "answer": "e", "capability_tags": [], "difficulty": null, "metadata": {}},
|
| 13 |
+
{"question_id": "psml-t2-2", "dataset": "PSML", "task_type": "T2_MCQ", "family": null, "prompt": "Q8", "options": ["h", "l"], "answer": "h", "capability_tags": [], "difficulty": null, "metadata": {}}
|
| 14 |
+
]
|
tests/fixtures/banks/causal_chambers_questions.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{"question_id": "cc-1", "dataset": "causal_chambers", "task_type": "T3", "family": "S5:E", "prompt": "CC1", "options": ["low", "high"], "answer": "low", "capability_tags": [], "difficulty": null, "metadata": {}},
|
| 3 |
+
{"question_id": "cc-2", "dataset": "causal_chambers", "task_type": "T1U", "family": null, "prompt": "CC2", "options": ["low", "high"], "answer": "high", "capability_tags": [], "difficulty": null, "metadata": {}}
|
| 4 |
+
]
|
tests/fixtures/banks/freshretailnet_questions.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{"question_id": "frn-1", "dataset": "freshretailnet", "task_type": "T1U", "family": null, "prompt": "FRN1", "options": ["A", "B"], "answer": "A", "capability_tags": [], "difficulty": null, "metadata": {}},
|
| 3 |
+
{"question_id": "frn-2", "dataset": "freshretailnet", "task_type": "T3", "family": "S1:A", "prompt": "FRN2", "options": ["A", "B"], "answer": "B", "capability_tags": [], "difficulty": null, "metadata": {}}
|
| 4 |
+
]
|
tests/test_env.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
from env.config import EnvConfig
|
| 6 |
+
from env.temporal_bench_env import TemporalBenchEnvironment
|
| 7 |
+
from models import TemporalBenchAction
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _bank_dir() -> Path:
    """Return the resolved ``fixtures/banks`` directory beside this module."""
    here = Path(__file__).resolve().parent
    return here.joinpath("fixtures", "banks")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def test_full_episode_all_correct():
    """Answer every question correctly and check step indices and totals."""
    cfg = EnvConfig(question_bank_path=str(_bank_dir()), seed=123, lambda_ep=0.5, alpha=1.0)
    env = TemporalBenchEnvironment(config=cfg)

    initial = env.reset(seed=123)
    assert not initial.done
    assert initial.step_idx == 0
    assert len(initial.options) >= 2

    last_step = cfg.num_questions - 1
    accumulated = 0.0
    result = initial
    for step in range(cfg.num_questions):
        # Peek at the pending question to submit its reference answer.
        question = env._questions[env._answered]  # noqa: SLF001
        result = env.step(TemporalBenchAction(answer=question.answer))
        accumulated += result.reward or 0.0
        if step == last_step:
            assert result.done
            assert result.step_idx == cfg.num_questions
        else:
            assert not result.done
            assert result.step_idx == step + 1

    final_state = env.state
    assert final_state.total_correct == cfg.num_questions
    assert final_state.step_count == cfg.num_questions
    assert final_state.total_reward == pytest.approx(accumulated)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def test_episode_with_one_wrong_answer():
    """Miss only the first question and check the terminal episode bonus.

    Expected values are derived from ``cfg.num_questions`` rather than a
    hard-coded episode length, so the test stays in step with the loop that
    drives the episode.
    """
    lambda_ep = 0.5
    cfg = EnvConfig(question_bank_path=str(_bank_dir()), seed=0, lambda_ep=lambda_ep, alpha=1.0)
    env = TemporalBenchEnvironment(config=cfg)
    env.reset(seed=0)

    total = cfg.num_questions
    expected_correct = total - 1  # only the first answer is wrong
    for i in range(total):
        # Peek at the pending question to choose a right or wrong option.
        cur = env._questions[env._answered]  # noqa: SLF001
        wrong = next(o for o in cur.options if o != cur.answer)
        act = TemporalBenchAction(answer=wrong if i == 0 else cur.answer)
        obs = env.step(act)
        if i == total - 1:
            assert obs.done
            meta = obs.metadata or {}
            assert "episode_bonus" in meta
            bonus = float(meta["episode_bonus"])
            accuracy = expected_correct / total
            # 0.8 and 1.0 bracket the final multiplier of the bonus
            # (presumably a domain-coverage factor -- confirm against
            # compute_episode_bonus).
            lo = lambda_ep * accuracy * 0.8
            hi = lambda_ep * accuracy * 1.0
            assert lo - 1e-9 <= bonus <= hi + 1e-9
    assert env.state.total_correct == expected_correct
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def test_double_step_after_done():
    """Stepping again after the episode has ended still reports ``done``."""
    cfg = EnvConfig(question_bank_path=str(_bank_dir()), seed=1)
    env = TemporalBenchEnvironment(config=cfg)
    env.reset(seed=1)

    # Play the whole episode with the reference answers.
    remaining = cfg.num_questions
    while remaining > 0:
        question = env._questions[env._answered]  # noqa: SLF001
        env.step(TemporalBenchAction(answer=question.answer))
        remaining -= 1

    extra = env.step(TemporalBenchAction(answer="x"))
    assert extra.done
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def test_empty_answer_no_advance():
    """A whitespace-only answer leaves the answered counter unchanged and reports an error."""
    cfg = EnvConfig(question_bank_path=str(_bank_dir()), seed=2)
    env = TemporalBenchEnvironment(config=cfg)
    env.reset(seed=2)

    answered_before = env._answered  # noqa: SLF001
    result = env.step(TemporalBenchAction(answer=" "))
    answered_after = env._answered  # noqa: SLF001

    assert answered_after == answered_before
    metadata = result.metadata or {}
    assert "error" in metadata
|
tests/test_loaders.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
from data.loaders import load_json_file, load_question_banks
|
| 6 |
+
from data.question import TSQuestion
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def test_load_json_file_roundtrip(fixture_bank_dir: Path):
    """Loading the PSML fixture file yields ``TSQuestion`` objects."""
    questions = load_json_file(fixture_bank_dir / "PSML_questions.json")
    assert len(questions) >= 6
    for question in questions:
        assert isinstance(question, TSQuestion)
    first = questions[0]
    assert first.dataset == "PSML"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def test_load_question_banks_directory(fixture_bank_dir: Path):
    """Every expected dataset appears in the loaded pools with at least one question."""
    expected_datasets = ("PSML", "freshretailnet", "MIMIC", "causal_chambers")
    pools = load_question_banks(str(fixture_bank_dir))
    for name in expected_datasets:
        assert name in pools
        assert len(pools[name]) >= 1
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def test_load_question_banks_rejects_non_dir():
    """A path that is not a directory raises ``NotADirectoryError``."""
    missing = "/nonexistent/path/that/is/not/a/dir"
    with pytest.raises(NotADirectoryError):
        load_question_banks(missing)
|
tests/test_reward.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
from env.reward import compute_episode_bonus, compute_forecasting_reward, compute_mcq_reward
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_compute_mcq_reward():
    """A correct answer with ``alpha=2.0`` gives 2.0; a zero score gives 0.0."""
    scaled = compute_mcq_reward(1.0, alpha=2.0)
    assert scaled == 2.0

    zero = compute_mcq_reward(0.0)
    assert zero == 0.0
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def test_compute_episode_bonus_coverage():
    """Check the episode bonus with one domain uncovered and with all covered.

    Expected values are compared with ``pytest.approx`` rather than ``==``:
    both sides are products of floats, so exact equality would depend on the
    order in which the implementation multiplies its factors.
    """
    domains = ("PSML", "freshretailnet", "MIMIC", "causal_chambers")

    # One of the four domains ("freshretailnet") has no coverage.
    cov = {"PSML": 1, "freshretailnet": 0, "MIMIC": 1, "causal_chambers": 1}
    b = compute_episode_bonus(3, 9, cov, all_domains=domains, lambda_ep=0.5)
    assert b == pytest.approx(0.5 * (3 / 9) * 0.8)

    # All four domains covered.
    cov2 = {"PSML": 1, "freshretailnet": 1, "MIMIC": 1, "causal_chambers": 1}
    b2 = compute_episode_bonus(4, 9, cov2, all_domains=domains, lambda_ep=0.5)
    assert b2 == pytest.approx(0.5 * (4 / 9) * 1.0)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def test_forecasting_reward_stub():
    """``compute_forecasting_reward`` raises ``NotImplementedError`` for array inputs."""
    import numpy as np

    prediction = np.array([1.0])
    target = np.array([1.0])
    with pytest.raises(NotImplementedError):
        compute_forecasting_reward(prediction, target)
|
tests/test_sampler.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from data.loaders import load_question_banks
|
| 4 |
+
from env.config import EnvConfig
|
| 5 |
+
from env.episode_sampler import EpisodeSampler
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def test_sample_episode_shape_and_domains(fixture_bank_dir):
    """A sampled episode has ``num_questions`` items and spans every configured domain."""
    bank_path = str(fixture_bank_dir)
    banks = load_question_banks(bank_path)
    cfg = EnvConfig(question_bank_path=bank_path, primary_domain="PSML", seed=42)
    sampler = EpisodeSampler(banks, cfg, np.random.default_rng(42))

    episode = sampler.sample_episode()

    assert len(episode) == cfg.num_questions
    seen_domains = set()
    for question in episode:
        seen_domains.add(question.dataset)
    assert seen_domains == set(cfg.all_domains)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def test_curriculum_stage_filters_task_types(fixture_bank_dir):
    """With ``curriculum_stage=1`` every sampled question has task type ``T1U``."""
    bank_path = str(fixture_bank_dir)
    banks = load_question_banks(bank_path)
    cfg = EnvConfig(question_bank_path=bank_path, curriculum_stage=1, seed=0)
    sampler = EpisodeSampler(banks, cfg, np.random.default_rng(0))

    # Sample several episodes in sequence from the same sampler.
    attempts = 5
    while attempts:
        episode = sampler.sample_episode()
        assert all(question.task_type == "T1U" for question in episode)
        attempts -= 1
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|