visheshrathi committed on
Commit f89b1ac · verified · 1 Parent(s): 8afce53

Upload folder using huggingface_hub

Dockerfile ADDED
@@ -0,0 +1,41 @@
+ FROM python:3.12-slim AS builder
+
+ ENV PYTHONDONTWRITEBYTECODE=1 \
+     PYTHONUNBUFFERED=1
+
+ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
+
+ WORKDIR /app
+
+ COPY pyproject.toml uv.lock README.md ./
+ RUN uv sync --frozen --no-install-project --no-dev
+
+ COPY __init__.py client.py env_loader.py inference.py models.py openenv.yaml ./
+ COPY server ./server
+ COPY data ./data
+ RUN uv sync --frozen --no-dev
+
+ FROM python:3.12-slim
+
+ ENV PYTHONDONTWRITEBYTECODE=1 \
+     PYTHONUNBUFFERED=1 \
+     HOST=0.0.0.0 \
+     PORT=7860 \
+     PATH="/app/.venv/bin:$PATH"
+
+ WORKDIR /app
+
+ RUN useradd -m appuser
+ COPY --from=builder --chown=appuser:appuser /app /app
+
+ RUN mkdir -p /app/workspace && chown -R appuser:appuser /app
+
+ USER appuser
+
+ HEALTHCHECK --interval=30s --timeout=5s --retries=3 \
+     CMD python -c "import os, urllib.request; urllib.request.urlopen('http://127.0.0.1:' + os.getenv('PORT', '7860') + '/health')" || exit 1
+
+ EXPOSE 7860
+
+ ENV ENABLE_WEB_INTERFACE=true
+ CMD ["python", "-m", "server"]
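The `HEALTHCHECK` instruction above packs its probe into a one-line `python -c` command. A more readable sketch of the same logic might look like the following; the function names `health_url` and `probe` are hypothetical, and the explicit `timeout_s` argument is an assumption added here (the Dockerfile relies on Docker's own `--timeout=5s` instead).

```python
import os
import urllib.request


def health_url(env=None) -> str:
    """Build the probe URL as the one-liner does: PORT env var, default 7860."""
    env = os.environ if env is None else env
    return "http://127.0.0.1:" + env.get("PORT", "7860") + "/health"


def probe(timeout_s: float = 5.0) -> bool:
    """Return True when /health answers with a 2xx status; any error is unhealthy."""
    try:
        with urllib.request.urlopen(health_url(), timeout=timeout_s) as resp:
            return 200 <= resp.status < 300
    except OSError:  # URLError and HTTPError are both OSError subclasses
        return False
```

Because the URL is derived from `PORT`, the check keeps working if the listen port is overridden at runtime, as long as the same variable reaches the container.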
README.md CHANGED
@@ -1,10 +1,252 @@
  ---
- title: Dataops Env
- emoji: 🐨
- colorFrom: pink
- colorTo: green
+ title: DataOpsEnv
+ emoji: 🧩
+ colorFrom: blue
+ colorTo: indigo
  sdk: docker
- pinned: false
+ app_port: 7860
+ short_description: OpenEnv DataOps — SQLite, ETL repair, three graded tasks.
+ tags:
+ - openenv
+ base_path: /web
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # DataOpsEnv
+
+ [Overview](#environment-description-and-motivation) · [Tasks](#tasks-descriptions-and-expected-difficulty) · [Setup and run](#setup-and-usage) · [Baseline scores](#baseline-scores) · [Hugging Face Spaces](#hugging-face-spaces) · [HTTP API](#api-reference) · [Tests](#tests)
+
+ ## Environment description and motivation
+
+ **DataOpsEnv** is an OpenEnv-compliant benchmark in which an agent performs data-engineering work: inspecting a small **SQLite** warehouse, **repairing Python ETL scripts**, and completing an **end-to-end reporting incident** (extract data, fix a formatter, send a mock email). Episodes are **seeded** (`reset` may include `seed`) so scenarios are **reproducible**; each HTTP session receives an **isolated workspace and database**.
+
+ Many agent benchmarks are game-like or shallow. Data cleaning, script debugging, and stakeholder communication reflect **real workflows**; this environment exercises multi-step tool use, constraint respect, and verifiable outcomes rather than single-shot question answering.
+
+ **Implementation:** FastAPI (`server/app.py`), environment logic (`server/dataops_env_environment.py`), terminal graders (`server/grading.py`), scenario definitions (`server/task_specs.py`, `data/init_db.py`), Pydantic types (`models.py`), OpenEnv manifest (`openenv.yaml`).
+
+ ---
+
+ ## Action space
+
+ Each step submits JSON: `{"action": {"action_type": "<type>", "payload": { ... }}}`. Payloads are validated per task (allowed files, SQL policy, email enabled only on the hard task).
+
+ | `action_type` | Payload fields | Role |
+ | ------------- | -------------- | ---- |
+ | `ExecuteSQL` | `query` (string, 1–2000 chars) | Run task-scoped SQL against the episode SQLite DB. |
+ | `ReadFile` | `filepath` (string, 1–255 chars) | Read an allowed file from the episode workspace. |
+ | `WriteFile` | `filepath`, `content` (content ≤ 1M chars) | Overwrite an allowed workspace file. |
+ | `RunScript` | `filepath` (must be `*.py` basename), `args` (optional list of strings, ≤ 20 args, each ≤ 500 chars) | Execute a Python script in the workspace with optional CLI args. |
+ | `SendEmail` | `to_email`, `subject`, `body` | Queue a mock email (used for the hard task). |
+
+ Machine-readable schema: **`GET /schema`** → `action`, or **`GET /tasks`** → `action_schema`.
+
+ ---
+
+ ## Observation space
+
+ Each `step` / `reset` response includes an observation object (REST also exposes wrapper fields such as `reward` / `done`). The table lists the OpenEnv base fields (`done`, `reward`, `metadata`) first, followed by the fields added by the **DataOps** layer.
+
+ | Field | Type | Meaning |
+ | ----- | ---- | ------- |
+ | `done` | boolean | Whether the episode has ended (step limit or terminal condition). |
+ | `reward` | number \| null | Shaped **step reward** after this transition (trajectory signal). |
+ | `metadata` | object | OpenEnv extension bucket (usually empty). |
+ | `status` | `"success"` \| `"error"` | Whether the action executed successfully. |
+ | `message` | string | Short human-readable summary. |
+ | `stdout` | string \| null | Captured stdout (e.g. script or file read). |
+ | `stderr` | string \| null | Captured stderr. |
+ | `sql_results` | list of objects \| null | Row dicts for successful `SELECT`-style outcomes. |
+ | `email_delivery_status` | string \| null | Mock send confirmation when applicable. |
+ | `step_count` | integer | Steps taken in the episode. |
+ | `max_steps` | integer | Episode step budget. |
+
+ **Terminal evaluation:** The **grader score** in **[0.0, 1.0]** is returned by **`GET /grader`** (or **`GET /grader/{task_id}`**) and reflects the **final** database, files, and outbox (and, for the hard task, **provenance** constraints). Hackathon-style evaluations typically treat the **grader** as the primary benchmark metric; step rewards remain a supplementary signal. Successful actions can still return **`reward=0.0`** when they neither improve grader state nor unlock a milestone.
+
+ Machine-readable schema: **`GET /schema`** → `observation`.
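To make the relationship between the REST wrapper fields and the observation object concrete, here is a small sketch on plain dictionaries. It follows the same merge the `DataOpsEnvClient._parse_observation` helper in `client.py` performs; the function names and the sample response are hypothetical.

```python
def merge_wrapper_fields(response: dict) -> dict:
    """Copy top-level `reward` / `done` into the observation when present."""
    observation = dict(response.get("observation", {}))
    for key in ("reward", "done"):
        if key in response:
            observation[key] = response[key]
    return observation


def steps_remaining(observation: dict) -> int:
    """Remaining step budget, derived from `max_steps` and `step_count`."""
    return max(observation["max_steps"] - observation["step_count"], 0)


# Hypothetical response shape, for illustration only.
sample = {
    "observation": {"status": "success", "message": "ok", "step_count": 3, "max_steps": 12},
    "reward": 0.0,
    "done": False,
}
merged = merge_wrapper_fields(sample)
```

A `reward` of `0.0` here is not an error signal: as noted above, a successful action can score zero when it does not change grader state.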
+
+ ---
+
+ ## Tasks (descriptions and expected difficulty)
+
+ | Task ID | Expected difficulty | Description |
+ | ------- | ------------------- | ----------- |
+ | `task_1_easy_anomaly` | **Easy** | The `transactions` table contains valid rows and rows with **NULL** `amount`. The agent must **delete only** the corrupted rows and leave all valid rows **unchanged**, including legitimate seeded zero-value or negative non-null adjustments. |
+ | `task_2_medium_syntax` | **Medium** | `broken_pipeline.py` is a seeded ETL normalization job with broken filtering, priority logic, and ordering. The agent must **read**, **patch**, and **run** the script so **`process_data_stream`** produces the correct downstream-ready records on both visible and hidden seeded batches. |
+ | `task_3_hard_e2e` | **Hard** | **End-to-end incident:** query the correct **`daily_reports`** slice for the **scenario date**, persist results as **`report_data.json`**, **repair** **`format_report.py`**, **run** it on that JSON, then **send exactly one** email whose **body matches** the formatter output, with scenario-specific **recipient** and **subject**. |
+
+ Task list, difficulty labels, and allowed actions per task: **`GET /tasks`** and **`openenv.yaml`**.
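The easy task's success condition can be reproduced locally with the standard-library `sqlite3` module. The sketch below uses the `transactions` schema from `data/init_db.py`; the sample rows are invented for illustration and are not the seeded scenario data.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE transactions (
        id INTEGER PRIMARY KEY,
        user_id INTEGER NOT NULL,
        amount REAL,
        status TEXT NOT NULL
    )
    """
)
# Hypothetical rows: two corrupted (NULL amount), plus a zero and a negative
# adjustment that are legitimate and must survive the cleanup.
conn.executemany(
    "INSERT INTO transactions VALUES (?, ?, ?, ?)",
    [
        (1, 10, 25.0, "success"),
        (2, 11, None, "success"),
        (3, 12, 0.0, "success"),
        (4, 13, -5.0, "refund"),
        (5, 14, None, "failed"),
    ],
)

# `IS NULL` targets only the corrupted rows; `amount <= 0` would be wrong here.
deleted = conn.execute("DELETE FROM transactions WHERE amount IS NULL").rowcount
remaining = conn.execute("SELECT id, amount FROM transactions ORDER BY id").fetchall()
conn.close()
```

The key point is the predicate: comparing against zero or filtering on sign would remove the valid adjustments the task description says must be preserved.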
+
+ ---
+
+ ## Setup and usage
+
+ **Prerequisites:** Python **3.12+**, **[uv](https://docs.astral.sh/uv/)**.
+
+ ```bash
+ uv sync
+ cp .env.example .env.dev
+ printf 'ENV_FILE=.env.dev\n' > .env
+ ```
+
+ Repo-root **`.env`** selects the active secondary env file. Use **`.env.dev`** for local runtime/model configuration. Hosted deployments that inject environment variables directly can skip both files.
+
+ **Run the server** from the repository root so **`HOST`**, **`PORT`**, and **`DEBUG`** from the active env file are honored:
+
+ ```bash
+ uv run python -m server
+ ```
+
+ Clients reuse the **`Set-Cookie`** session cookie or **`X-Session-ID`** header from **`POST /reset`** on **`/step`**, **`/state`**, and **`/grader`**.
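For clients that do not keep a cookie jar, the header route can be sketched as follows. This assumes the `POST /reset` response carries an `X-Session-ID` header, as the sentence above indicates; the helper name `session_headers` is hypothetical and the token is treated as opaque.

```python
def session_headers(reset_response_headers: dict[str, str]) -> dict[str, str]:
    """Extract the session ID from /reset response headers (case-insensitive)
    and return the header to attach to /step, /state, and /grader requests."""
    for name, value in reset_response_headers.items():
        if name.lower() == "x-session-id":
            return {"X-Session-ID": value}
    raise KeyError("X-Session-ID not present in the reset response headers")
```

With `requests`, a `requests.Session` already replays the cookie automatically, which is what `DataOpsEnvClient` in `client.py` relies on; the header variant is mainly useful for stateless callers.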
+
+ **OpenEnv packaging:**
+
+ ```bash
+ uv run openenv validate
+ ```
+
+ **Docker:**
+
+ ```bash
+ printf 'ENV_FILE=.env.dev\n' > .env
+ bash build_and_run_image.sh
+ ```
+
+ The helper script reads repo-root **`.env`** only to resolve **`ENV_FILE`**, then passes that secondary file to `docker run --env-file ...`. The container does **not** receive a merged view of repo-root **`.env`** plus the secondary file. Keeping `.env*` out of the image is intentional; runtime configuration is injected from the host.
+
+ **Baseline inference (local):**
+
+ | Variable | Purpose |
+ | -------- | ------- |
+ | `ENV_BASE_URL` | Environment server URL (default `http://127.0.0.1:$PORT`, with `PORT=7860` by default) |
+ | `API_KEY` / `HF_TOKEN` | Model access credential; set exactly one of the two |
+ | `API_BASE_URL` | Optional model provider base URL override |
+ | `MODEL_NAME` | Optional chat model ID |
+
+ ```bash
+ export ENV_BASE_URL=http://127.0.0.1:7860
+ uv run python inference.py --seed 7 --max-turns 12
+ ```
+
+ If **`API_BASE_URL`** is unset, `inference.py` defaults to Google's OpenAI-compatible Gemini endpoint for **`API_KEY`** and Hugging Face's router for **`HF_TOKEN`**.
+
+ Flags: `--task` (repeatable), `--seed`, `--max-turns`, `--json-scores` (emits one JSON object on stdout after the harness lines, including raw grader payloads when available). When `PUBLIC_GRADER_DETAILS=true` and the grader API exposes details, `inference.py` also writes the per-task grader payloads to `stderr`.
+
+ **`POST /baseline`** runs the same script inside the server process; optional JSON body: `task_ids`, `seed`, `max_turns`. If **`ADMIN_API_KEY`** is unset, the route is open. If **`ADMIN_API_KEY`** is set, callers must send **`X-Admin-Key`**. If **`ENV_BASE_URL`** is unset, the server injects **`http://127.0.0.1:$PORT`** into the child process automatically.
+
+ Agent-executed Python scripts run with a stripped environment, bounded resources, and capped captured output so task verification does not inherit model-provider secrets from the server process.
+
+ **Minimal HTTP smoke test:**
+
+ ```bash
+ curl -c cookies.txt -X POST 'http://127.0.0.1:7860/reset?task_id=task_1_easy_anomaly' \
+   -H 'Content-Type: application/json' \
+   -d '{"seed": 7}'
+
+ curl -b cookies.txt -X POST 'http://127.0.0.1:7860/step' \
+   -H 'Content-Type: application/json' \
+   -d '{"action":{"action_type":"ExecuteSQL","payload":{"query":"DELETE FROM transactions WHERE amount IS NULL"}}}'
+
+ curl -b cookies.txt 'http://127.0.0.1:7860/grader'
+ ```
+
+ By default **`/grader`** returns **`task_id`** and **`score`** only. Full grader **`details`** require **`PUBLIC_GRADER_DETAILS=true`** or a valid **`X-Admin-Key`** when **`ADMIN_API_KEY`** is set. This does **not** change the mandatory `[START]` / `[STEP]` / `[END]` lines from `inference.py`; it affects the grader API, the optional trailing JSON emitted by `--json-scores`, and the captured `stderr` payloads written by `inference.py`.
+
+ ---
+
+ ## Baseline scores
+
+ All figures are **terminal grader** scores in **[0.0, 1.0]**. Scores depend on provider, model revision, temperature, and `seed`.
+
+ ### Null baseline (no agent actions)
+
+ | Condition | `task_1` | `task_2` | `task_3` | Avg |
+ | --------- | -------- | -------- | -------- | --- |
+ | `reset` only (`seed=7`), then grader; **no** `/step` | 0.00 | 0.00 | 0.00 | 0.00 |
+
+ ### Reference tool-calling baseline
+
+ `[END] success=true` in the harness logs means the terminal grader reached **1.0** for that task.
+
+ | Model | Seed | `task_1_easy_anomaly` | `task_2_medium_syntax` | `task_3_hard_e2e` | Average |
+ | ----- | ---- | --------------------- | ---------------------- | ----------------- | ------- |
+ | `gemini-3.1-flash-lite-preview` | 7 | 1.00 | 1.00 | 1.00 | 1.00 |
+
+ **Reproducing a baseline run:** With the API server running locally on `7860` and model credentials configured, run:
+
+ ```bash
+ export MODEL_NAME=gemini-3.1-flash-lite-preview
+ export ENV_BASE_URL=http://127.0.0.1:7860
+ uv run python inference.py --seed 7 --max-turns 12 --json-scores
+ ```
+
+ The final line of stdout is a single JSON object with **`scores`**, **`grades`**, **`average`**, **`model`**, and **`metadata`**.
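Since the JSON object is the last line of stdout, a consumer can skip the harness lines and parse only that line. The sketch below assumes exactly that layout; `parse_json_scores` is a hypothetical helper, and the sample text abbreviates the harness lines rather than reproducing their real format.

```python
import json


def parse_json_scores(stdout_text: str) -> dict:
    """Return the trailing JSON object emitted by `--json-scores`."""
    lines = [line for line in stdout_text.splitlines() if line.strip()]
    if not lines:
        raise ValueError("no output captured")
    return json.loads(lines[-1])


# Abbreviated, hypothetical capture: harness lines first, JSON object last.
captured = "\n".join(
    [
        "[START] ...",
        "[STEP] ...",
        "[END] ...",
        json.dumps({"scores": {"task_1_easy_anomaly": 1.0}, "average": 1.0}),
    ]
)
result = parse_json_scores(captured)
```

Parsing only the last non-empty line keeps the consumer independent of how many `[STEP]` lines a run produces.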
+
+ ---
+
+ ## Hugging Face Spaces
+
+ There are two ways to run the baseline against a deployed Hugging Face Space:
+
+ 1. Running **`inference.py`** externally against the public Space URL:
+
+ ```bash
+ export ENV_BASE_URL=https://visheshrathi-dataops-env.hf.space
+ uv run python inference.py --seed 7 --max-turns 12 --json-scores
+ ```
+
+ In this mode, the Space only needs to expose the environment API (`/reset`, `/step`, `/grader`, `/tasks`, `/schema`, `/health`, `/metadata`, `/ws`, `/mcp`). Model credentials are provided on the machine that runs **`inference.py`**, not on the Space.
+
+ 2. Sending a `POST` request to **`/baseline`**:
+
+ ```bash
+ curl -X POST 'https://visheshrathi-dataops-env.hf.space/baseline' \
+   -H 'Content-Type: application/json' \
+   -d '{"seed": 7, "max_turns": 12}'
+ ```
+
+ In this mode, the Space itself executes **`inference.py`**. Configure one model credential source on the Space (**`API_KEY`** or **`HF_TOKEN`**). **`MODEL_NAME`** and **`API_BASE_URL`** are optional overrides. **`ENV_BASE_URL`** is not required for **`POST /baseline`** because the server injects **`http://127.0.0.1:$PORT`** when it launches the child `inference.py` process. If **`ADMIN_API_KEY`** is unset, **`POST /baseline`** is open; if it is set, callers must send **`X-Admin-Key`**.
+
+ ---
+
+ ## API reference
+
+ | Method | Path | Purpose |
+ | ------ | ---- | ------- |
+ | GET | `/health` | Liveness |
+ | GET | `/metadata` | Name, description, version, task count |
+ | GET | `/schema` | JSON Schemas: action, observation, state |
+ | GET | `/tasks` | Tasks + action/observation/state schemas |
+ | POST | `/mcp` | Minimal JSON-RPC tool-list compatibility stub |
+ | POST | `/reset?task_id=...` | New episode; body may include `seed`, `episode_id` |
+ | POST | `/step` | One action; optional `timeout_s` |
+ | GET | `/state` | Episode state (`task_id`, `seed`, …) |
+ | GET | `/grader` | Terminal score for active task |
+ | GET | `/grader/{task_id}` | Same; `task_id` must match the active task |
+ | POST | `/baseline` | Subprocess baseline (see [Setup and usage](#setup-and-usage)) |
+ | WS | `/ws` | OpenEnv WebSocket session |
+
+ ---
+
+ ## Environment variables (server / container)
+
+ | Variable | Purpose |
+ | -------- | ------- |
+ | `HOST` | Listen host used by `python -m server` and the container entrypoint |
+ | `PORT` | Listen port used by `python -m server` and the container entrypoint |
+ | `DEBUG` | Enables reload for local `python -m server` runs |
+ | `ENV_FILE` | Repo-relative dotenv loaded after `.env` (override) |
+ | `HTTP_SESSION_TIMEOUT_S` | HTTP session idle TTL; max wall time for **`POST /baseline`** child |
+ | `MAX_HTTP_SESSIONS` | Concurrent HTTP sessions cap |
+ | `MAX_WS_SESSIONS` | Concurrent WebSocket sessions cap |
+ | `ADMIN_API_KEY` | When set, protects **`POST /baseline`** and lets **`X-Admin-Key`** unlock full grader details |
+ | `PUBLIC_GRADER_DETAILS` | If `true`, public **`/grader`** and **`/grader/{task_id}`** responses include **`details`** |
+ | `COOKIE_SECURE` | Set `Secure` on session cookies (HTTPS) |
+ | `CORS_ALLOW_ORIGINS` | Comma-separated origins; empty disables permissive CORS (recommended default) |
+
+ ---
+
+ ## Tests
+
+ ```bash
+ uv sync --extra dev
+ uv run pytest -q
+ ```
__init__.py ADDED
@@ -0,0 +1,14 @@
+ """DataOps Environment — OpenEnv-compliant enterprise data pipeline remediation environment."""
+
+ try:
+     from .client import DataOpsEnv
+     from .models import DataOpsAction, DataOpsObservation
+ except ImportError:  # pragma: no cover — flat imports when loaded as top-level __init__ (e.g. pytest)
+     from client import DataOpsEnv
+     from models import DataOpsAction, DataOpsObservation
+
+ __all__ = [
+     "DataOpsAction",
+     "DataOpsObservation",
+     "DataOpsEnv",
+ ]
client.py ADDED
@@ -0,0 +1,93 @@
+ """Typed clients for the DataOpsEnv environment."""
+
+ from typing import Optional
+
+ import requests
+
+ from openenv.core.client_types import StepResult
+ from openenv.core.env_client import EnvClient
+
+ from models import DataOpsAction, DataOpsObservation, DataOpsState
+
+
+ class DataOpsEnv(EnvClient[DataOpsAction, DataOpsObservation, DataOpsState]):
+     """Native OpenEnv WebSocket client for persistent sessions."""
+
+     def _step_payload(self, action: DataOpsAction) -> dict:
+         return action.model_dump()
+
+     def _parse_result(self, payload: dict) -> StepResult[DataOpsObservation]:
+         observation = DataOpsObservation(
+             **payload.get("observation", {}),
+             reward=payload.get("reward"),
+             done=payload.get("done", False),
+         )
+         return StepResult(
+             observation=observation,
+             reward=payload.get("reward"),
+             done=payload.get("done", False),
+         )
+
+     def _parse_state(self, payload: dict) -> DataOpsState:
+         return DataOpsState(**payload)
+
+
+ class DataOpsEnvClient:
+     """Compatibility HTTP client for the validator-facing REST API."""
+
+     def __init__(
+         self, base_url: str = "http://127.0.0.1:7860", timeout: float = 30.0
+     ) -> None:
+         self.base_url = base_url.rstrip("/")
+         self.timeout = timeout
+         self._session = requests.Session()
+
+     @staticmethod
+     def _parse_observation(payload: dict) -> DataOpsObservation:
+         observation_payload = dict(payload.get("observation", {}))
+         if "reward" in payload:
+             observation_payload["reward"] = payload["reward"]
+         if "done" in payload:
+             observation_payload["done"] = payload["done"]
+         return DataOpsObservation(**observation_payload)
+
+     def reset(
+         self, task_id: str = "task_1_easy_anomaly", seed: Optional[int] = None,
+     ) -> DataOpsObservation:
+         resp = self._session.post(
+             f"{self.base_url}/reset",
+             params={"task_id": task_id},
+             json={"seed": seed},
+             timeout=self.timeout,
+         )
+         resp.raise_for_status()
+         return self._parse_observation(resp.json())
+
+     def step(self, action: DataOpsAction) -> DataOpsObservation:
+         resp = self._session.post(
+             f"{self.base_url}/step",
+             json={"action": action.model_dump()},
+             timeout=self.timeout,
+         )
+         resp.raise_for_status()
+         return self._parse_observation(resp.json())
+
+     def state(self) -> DataOpsState:
+         resp = self._session.get(f"{self.base_url}/state", timeout=self.timeout)
+         resp.raise_for_status()
+         return DataOpsState(**resp.json())
+
+     def grade(self, task_id: Optional[str] = None) -> dict:
+         url = f"{self.base_url}/grader/{task_id}" if task_id else f"{self.base_url}/grader"
+         resp = self._session.get(url, timeout=self.timeout)
+         resp.raise_for_status()
+         return resp.json()
+
+     def close(self) -> None:
+         self._session.close()
+
+     def __enter__(self) -> "DataOpsEnvClient":
+         return self
+
+     def __exit__(self, *args: object) -> None:
+         self.close()
data/__init__.py ADDED
File without changes
data/init_db.py ADDED
@@ -0,0 +1,130 @@
+ import logging
+ import os
+ import shutil
+ import sqlite3
+
+ from server.task_specs import TaskScenarioBundle, build_task_scenario
+
+ logger = logging.getLogger(__name__)
+
+ WORKSPACE_ROOT = os.path.join(os.path.dirname(os.path.dirname(__file__)), "workspace")
+ WORKSPACE_DIR = WORKSPACE_ROOT
+
+
+ def setup_workspace(
+     workspace_dir: str | None = None, *, scenario: TaskScenarioBundle | None = None
+ ) -> str:
+     """Initialise an isolated episode workspace from the seeded scenario."""
+     target_workspace = workspace_dir or WORKSPACE_DIR
+     target_db_path = os.path.join(target_workspace, "mock_warehouse.db")
+     resolved_scenario = scenario or build_task_scenario("task_1_easy_anomaly", seed=0)
+     os.makedirs(target_workspace, exist_ok=True)
+
+     _clear_workspace(target_workspace)
+     _init_database(target_db_path, resolved_scenario)
+     _write_seeded_files(target_workspace, resolved_scenario)
+
+     logger.info(
+         "Workspace reset complete: task=%s seed=%s db=%s",
+         resolved_scenario.task_id,
+         resolved_scenario.seed,
+         target_db_path,
+     )
+     return target_db_path
+
+
+ def _clear_workspace(workspace_dir: str) -> None:
+     for entry in os.listdir(workspace_dir):
+         path = os.path.join(workspace_dir, entry)
+         try:
+             if os.path.isdir(path):
+                 shutil.rmtree(path)
+             else:
+                 os.remove(path)
+         except FileNotFoundError:
+             continue
+
+
+ def _init_database(db_path: str, scenario: TaskScenarioBundle) -> None:
+     conn = sqlite3.connect(db_path)
+     try:
+         c = conn.cursor()
+         c.execute(
+             """
+             CREATE TABLE transactions (
+                 id INTEGER PRIMARY KEY,
+                 user_id INTEGER NOT NULL,
+                 amount REAL,
+                 status TEXT NOT NULL
+             )
+             """
+         )
+         c.execute(
+             """
+             CREATE TABLE daily_reports (
+                 id INTEGER PRIMARY KEY,
+                 report_date TEXT NOT NULL,
+                 department TEXT NOT NULL,
+                 revenue REAL NOT NULL,
+                 expenses REAL NOT NULL,
+                 headcount INTEGER NOT NULL
+             )
+             """
+         )
+
+         if scenario.task_1:
+             c.executemany(
+                 "INSERT INTO transactions VALUES (?, ?, ?, ?)",
+                 [
+                     (row["id"], row["user_id"], row["amount"], row["status"])
+                     for row in scenario.task_1.all_rows
+                 ],
+             )
+         else:
+             c.executemany(
+                 "INSERT INTO transactions VALUES (?, ?, ?, ?)",
+                 [(1, 9000, 100.0, "success")],
+             )
+
+         if scenario.task_3:
+             c.executemany(
+                 "INSERT INTO daily_reports VALUES (?, ?, ?, ?, ?, ?)",
+                 [
+                     (
+                         row["id"],
+                         row["report_date"],
+                         row["department"],
+                         row["revenue"],
+                         row["expenses"],
+                         row["headcount"],
+                     )
+                     for row in scenario.task_3.all_rows
+                 ],
+             )
+
+         conn.commit()
+     finally:
+         conn.close()
+
+
+ def _write_seeded_files(workspace_dir: str, scenario: TaskScenarioBundle) -> None:
+     if scenario.task_2:
+         with open(
+             os.path.join(workspace_dir, "broken_pipeline.py"),
+             "w",
+             encoding="utf-8",
+         ) as f:
+             f.write(scenario.task_2.broken_script)
+
+     if scenario.task_3:
+         with open(
+             os.path.join(workspace_dir, "format_report.py"),
+             "w",
+             encoding="utf-8",
+         ) as f:
+             f.write(scenario.task_3.broken_script)
+
+
+ if __name__ == "__main__":
+     logging.basicConfig(level=logging.INFO)
+     setup_workspace()
env_loader.py ADDED
@@ -0,0 +1,82 @@
+ """Load local env files, while allowing externally injected container env vars."""
+
+ import logging
+ import os
+ from pathlib import Path
+
+ from dotenv import dotenv_values, load_dotenv
+
+ logger = logging.getLogger(__name__)
+
+ _PROJECT_ROOT = Path(__file__).resolve().parent
+ _RUNTIME_ENV_KEYS = (
+     "ENV_FILE",
+     "HOST",
+     "PORT",
+     "DEBUG",
+     "ENV_BASE_URL",
+     "ADMIN_API_KEY",
+     "PUBLIC_GRADER_DETAILS",
+     "COOKIE_SECURE",
+     "HTTP_SESSION_TIMEOUT_S",
+     "CORS_ALLOW_ORIGINS",
+     "MAX_HTTP_SESSIONS",
+     "MAX_WS_SESSIONS",
+     "API_KEY",
+     "HF_TOKEN",
+     "MODEL_NAME",
+     "API_BASE_URL",
+ )
+
+
+ def _has_external_runtime_config() -> bool:
+     return any(bool(os.getenv(key, "").strip()) for key in _RUNTIME_ENV_KEYS)
+
+
+ def _resolve_env_file(env_file_name: str) -> Path:
+     env_path = (_PROJECT_ROOT / env_file_name).resolve()
+     try:
+         env_path.relative_to(_PROJECT_ROOT)
+     except ValueError as exc:
+         raise ValueError(
+             f"ENV_FILE '{env_file_name}' must resolve inside the project root."
+         ) from exc
+     return env_path
+
+
+ def load_env() -> None:
+     """Read repo-root .env to locate the active secondary env file."""
+     dot_env = _PROJECT_ROOT / ".env"
+     if not dot_env.exists():
+         if _has_external_runtime_config():
+             logger.debug(
+                 ".env not found at %s — assuming runtime env vars were injected externally",
+                 dot_env,
+             )
+             return
+         logger.warning(
+             ".env not found at %s — local runs expect it to define ENV_FILE for the active env file",
+             dot_env,
+         )
+         return
+
+     load_dotenv(dot_env, override=False)
+     env_file_name = str(
+         (dotenv_values(dot_env).get("ENV_FILE") or os.getenv("ENV_FILE") or "")
+     ).strip()
+     if not env_file_name:
+         logger.debug(".env did not specify ENV_FILE — no secondary file loaded")
+         return
+
+     try:
+         env_file = _resolve_env_file(env_file_name)
+     except ValueError as exc:
+         logger.warning("%s", exc)
+         return
+
+     if not env_file.exists():
+         logger.warning("ENV_FILE '%s' not found at %s", env_file_name, env_file)
+         return
+
+     load_dotenv(env_file, override=True)
+     logger.debug("Loaded environment variables from %s", env_file)
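The containment check in `_resolve_env_file` can be tried in isolation. The following is a standalone sketch of the same technique (resolve the candidate, then require it to be relative to the root); `resolve_inside` is a hypothetical stand-in that takes the root as a parameter and also resolves the root itself, which the original does once at import time.

```python
import tempfile
from pathlib import Path


def resolve_inside(root: Path, name: str) -> Path:
    """Resolve `name` under `root`; raise ValueError if it escapes `root`."""
    root = root.resolve()
    candidate = (root / name).resolve()
    candidate.relative_to(root)  # raises ValueError when candidate is outside root
    return candidate


with tempfile.TemporaryDirectory() as tmp:
    project_root = Path(tmp)
    inside = resolve_inside(project_root, ".env.dev")
    inside_ok = inside.name == ".env.dev" and inside.parent == project_root.resolve()
    try:
        resolve_inside(project_root, "../outside.env")
        escape_blocked = False
    except ValueError:
        escape_blocked = True
```

Resolving before comparing is what defeats `..` segments: the comparison runs on the normalized path rather than on the raw string from `ENV_FILE`.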
inference.py ADDED
@@ -0,0 +1,589 @@
+ """
+ DataOps benchmark runner: drives the sandbox over HTTP (`/reset`, `/step`, `/grader`) with an OpenAI
+ tool-calling loop. Tool schemas are task-scoped (e.g. send_email only for the hard E2E task).
+
+ Flow per task: reset → chat completions (prefer `tool_choice="required"`) → validate tool args → POST each action →
+ append tool/observation messages until the env reports `done` or `max_turns` → GET grader score. Success is
+ derived from the score vs `SUCCESS_SCORE_THRESHOLD`.
+
+ Stdout is the harness protocol only: one `[START]`, one `[STEP]` per env step, one `[END]` (always). Use
+ `--json-scores` to append a single JSON object (scores, average, metadata) for `/baseline` ingestion.
+
+ CLI: `--task` (repeatable), `--seed`, `--max-turns`, `--json-scores`. The environment HTTP base URL comes from
+ `ENV_BASE_URL`, or if unset `http://127.0.0.1:$PORT` (default port 7860). Auth uses either `API_KEY` or
+ `HF_TOKEN`. `API_BASE_URL` is optional: when omitted, the runner defaults to Google's OpenAI-compatible Gemini
+ endpoint for `API_KEY` and Hugging Face's router for `HF_TOKEN`.
+
+ Library logging is disabled so parsers see only these lines.
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import asyncio
+ import json
+ import logging
+ import os
+ import re
+ import sys
+ import zlib
+ from datetime import datetime, timezone
+ from typing import Any, Optional, Type
+
+ import requests
+ from openai import BadRequestError, OpenAI
+ from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
+ from pydantic import BaseModel, ValidationError
+
+ from env_loader import load_env
+ from models import (
+     ExecuteSQLPayload,
+     ReadFilePayload,
+     RunScriptPayload,
+     SendEmailPayload,
+     WriteFilePayload,
+ )
+ from server.task_specs import TASK_IDS, TASK_METADATA
+
+ # Silence all library logging (httpx, openai, urllib3, env_loader, etc.).
+ logging.disable(logging.CRITICAL)
+
+ load_env()
+
+ DEFAULT_GOOGLE_OPENAI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
+ DEFAULT_HF_OPENAI_BASE_URL = "https://router.huggingface.co/v1"
+
+ _DEFAULT_PORT = int(os.getenv("PORT", "7860"))
+ ENV_BASE_URL = os.getenv("ENV_BASE_URL") or f"http://127.0.0.1:{_DEFAULT_PORT}"
+ MODEL_NAME = os.getenv("MODEL_NAME") or "gemini-3.1-flash-lite-preview"
+
+ BENCHMARK = "dataops_env"
+ MAX_TURNS = 12
+ SUCCESS_SCORE_THRESHOLD = 1.0
+
+ _TOOL_HELP: dict[str, str] = {
+     "execute_sql": "execute_sql — SQL over the task warehouse (field: query).",
+     "read_file": "read_file — read a workspace file (field: filepath).",
+     "write_file": "write_file — overwrite a file (fields: filepath, content).",
+     "invoke_python": "invoke_python — run a Python script (fields: filepath, optional args).",
+     "send_email": "send_email — send email (fields: to_email, subject, body).",
+ }
+
+ _ACTION_TO_TOOL: dict[str, str] = {
+     "ExecuteSQL": "execute_sql",
+     "ReadFile": "read_file",
+     "WriteFile": "write_file",
+     "RunScript": "invoke_python",
+     "SendEmail": "send_email",
+ }
+
+
+ def _allowed_tool_names_csv(task_id: str) -> str:
+     order = (
+         "execute_sql",
+         "read_file",
+         "write_file",
+         "invoke_python",
+         "send_email",
+     )
+     allowed = {_ACTION_TO_TOOL[a] for a in TASK_METADATA[task_id].allowed_actions}
+     return ", ".join(t for t in order if t in allowed)
+
+
+ def _system_prompt_for_task(task_id: str) -> str:
+     lines = [
+         _TOOL_HELP[t]
+         for t in (
+             "execute_sql",
+             "read_file",
+             "write_file",
+             "invoke_python",
+             "send_email",
+         )
+         if t in {_ACTION_TO_TOOL[a] for a in TASK_METADATA[task_id].allowed_actions}
+     ]
+     tools_block = "\n".join(f" - {line}" for line in lines)
+     return f"""\
+ You are an expert DataOps agent in a task-scoped benchmark. Only the tools listed below exist for this task — do not assume other actions are available.
+
+ Available tools:
+ {tools_block}
+
+ Rules:
+ - Always read files before modifying them when read_file is available.
+ - After writing a fix, run the script to verify it works when invoke_python is available.
+ - Be precise. Do not drop tables. Do not guess — inspect first.
+ - For tasks that include send_email, match subject and body to the task description exactly.
+ """
+
+
+ TASK_PROMPTS = {
+     "task_1_easy_anomaly": (
+         "Solve the seeded cleanup task carefully. Inspect before mutating. Only NULL-amount rows are corrupted; preserve every non-null row exactly, including legitimate zero or negative adjustments."
+     ),
+     "task_2_medium_syntax": (
+         "Solve the seeded script-repair task. Read the file, make the minimal correct fix, and verify with execution."
126
+ ),
127
+ "task_3_hard_e2e": (
128
+ "Solve the seeded incident task end to end. Use SQL for the exact slice, write the exact JSON file, "
129
+ "repair the formatter, execute it, and email the exact generated report."
130
+ ),
131
+ }
132
+
133
+
134
+ def log_start(task: str, env: str, model: str) -> None:
135
+ print(f"[START] task={task} env={env} model={model}", flush=True)
136
+
137
+
138
+ def log_step(
139
+ step: int, action: str, reward: float, done: bool, error: Optional[str]
140
+ ) -> None:
141
+ error_val = error if error else "null"
142
+ done_val = str(done).lower()
143
+ print(
144
+ f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
145
+ flush=True,
146
+ )
147
+
148
+
149
+ def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
150
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
151
+ print(
152
+ f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
153
+ flush=True,
154
+ )
155
+
156
+
157
+ def _public_grader_details_enabled() -> bool:
158
+ return os.getenv("PUBLIC_GRADER_DETAILS", "").strip().lower() in {"1", "true", "yes"}
159
+
160
+
161
+ def _emit_grader_details_to_stderr(grade: dict[str, Any]) -> None:
162
+ if not _public_grader_details_enabled():
163
+ return
164
+ if "details" not in grade:
165
+ return
166
+ print(json.dumps(grade, ensure_ascii=False), file=sys.stderr, flush=True)
167
+
168
+
169
+ def _request_json(
170
+ http: requests.Session,
171
+ method: str,
172
+ path: str,
173
+ *,
174
+ timeout: float,
175
+ **kwargs: Any,
176
+ ) -> dict[str, Any]:
177
+ response = http.request(method, f"{ENV_BASE_URL}{path}", timeout=timeout, **kwargs)
178
+ response.raise_for_status()
179
+ return response.json()
180
+
181
+
182
+ def _build_tools(task_id: str) -> list[ChatCompletionToolParam]:
183
+ defs: dict[str, tuple[str, Type[BaseModel]]] = {
184
+ "execute_sql": (
185
+ "Run a task-scoped SQL query against the SQLite warehouse DB.",
186
+ ExecuteSQLPayload,
187
+ ),
188
+ "read_file": ("Read a file in the workspace.", ReadFilePayload),
189
+ "write_file": ("Overwrite a file with new content.", WriteFilePayload),
190
+ "invoke_python": (
191
+ "Execute a Python script in the workspace (optional args).",
192
+ RunScriptPayload,
193
+ ),
194
+ "send_email": ("Send a formatted email notification.", SendEmailPayload),
195
+ }
196
+ allowed_names = {_ACTION_TO_TOOL[a] for a in TASK_METADATA[task_id].allowed_actions}
197
+ return [
198
+ {
199
+ "type": "function",
200
+ "function": {
201
+ "name": name,
202
+ "description": defs[name][0],
203
+ "parameters": defs[name][1].model_json_schema(),
204
+ },
205
+ }
206
+ for name in (
207
+ "execute_sql",
208
+ "read_file",
209
+ "write_file",
210
+ "invoke_python",
211
+ "send_email",
212
+ )
213
+ if name in allowed_names
214
+ ]
215
+
216
+
217
+ def _tool_call_to_action(name: str, arguments: str) -> dict[str, Any]:
218
+ if name == "run_script":
219
+ name = "invoke_python"
220
+ mapping: dict[str, tuple[str, Type[BaseModel]]] = {
221
+ "execute_sql": ("ExecuteSQL", ExecuteSQLPayload),
222
+ "read_file": ("ReadFile", ReadFilePayload),
223
+ "write_file": ("WriteFile", WriteFilePayload),
224
+ "invoke_python": ("RunScript", RunScriptPayload),
225
+ "send_email": ("SendEmail", SendEmailPayload),
226
+ }
227
+ if name not in mapping:
228
+ raise ValueError(f"Unknown tool: {name}")
229
+ action_type, model = mapping[name]
230
+ data = json.loads(arguments) if (arguments or "").strip() else {}
231
+ payload = model.model_validate(data).model_dump()
232
+ return {"action_type": action_type, "payload": payload}
233
+
234
+
235
+ _MALFORMED_TOOL = re.compile(
236
+ r"^([a-zA-Z_][a-zA-Z0-9_]*)[\s,=\(]+(\{.*\})\)?\s*$", re.DOTALL
237
+ )
238
+
239
+
240
+ def _normalize_tool_name_and_args(name: str, arguments: str) -> tuple[str, str]:
241
+ name = (name or "").strip()
242
+ arguments = (arguments or "").strip()
243
+ m = _MALFORMED_TOOL.match(name)
244
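+     if m:
+         base, embedded = m.group(1).strip(), m.group(2).strip()
+         # Always recover the bare tool name; prefer explicit arguments and
+         # fall back to the JSON that was embedded in the name.
+         return base, arguments or embedded
+     return name, arguments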
249
+
250
+
251
+ def _action_from_tool_call(tc: Any) -> dict[str, Any]:
252
+ name, arguments = _normalize_tool_name_and_args(
253
+ tc.function.name or "", tc.function.arguments or ""
254
+ )
255
+ return _tool_call_to_action(name, arguments)
256
+
257
+
258
+ def _action_str(action_payload: dict[str, Any]) -> str:
259
+ at = action_payload.get("action_type", "")
260
+ pl = action_payload.get("payload") or {}
261
+ raw = f"{at}({json.dumps(pl, ensure_ascii=False)})"
262
+ if len(raw) > 1200:
263
+ return raw[:600] + "..." + raw[-550:]
264
+ return raw
265
+
266
+
267
+ def _obs_error(obs: dict[str, Any]) -> Optional[str]:
268
+ if obs.get("status") != "error":
269
+ return None
270
+ msg = obs.get("message")
271
+ if isinstance(msg, str) and msg.strip():
272
+ return msg.replace("\n", " ").strip()
273
+ return None
274
+
275
+
276
+ def _resolve_api_base_url() -> str:
277
+ explicit = os.getenv("API_BASE_URL", "").strip()
278
+ if explicit:
279
+ return explicit
280
+ if os.getenv("HF_TOKEN", "").strip():
281
+ return DEFAULT_HF_OPENAI_BASE_URL
282
+ return DEFAULT_GOOGLE_OPENAI_BASE_URL
283
+
284
+
285
+ API_BASE_URL = _resolve_api_base_url()
286
+
287
+
288
+ def _openai_client() -> OpenAI:
289
+ key = (os.getenv("HF_TOKEN") or os.getenv("API_KEY") or "").strip()
290
+ if not key:
291
+ print(
292
+ "[inference] Missing API_KEY or HF_TOKEN for model access.",
293
+ file=sys.stderr,
294
+ flush=True,
295
+ )
296
+ sys.exit(1)
297
+ return OpenAI(api_key=key, base_url=API_BASE_URL)
298
+
299
+
300
+ def _llm_seed(env_seed: int | None, task_id: str) -> int | None:
301
+ if env_seed is None:
302
+ return None
303
+ mixed = (int(env_seed) * 1_000_003) ^ (zlib.crc32(task_id.encode()) & 0xFFFFFFFF)
304
+ return mixed & 0x7FFFFFFF
305
+
306
+
307
+ def _create_chat_completion(
308
+ client: OpenAI,
309
+ messages: list[ChatCompletionMessageParam],
310
+ tools: list[ChatCompletionToolParam],
311
+ *,
312
+ task_id: str,
313
+ env_seed: int | None,
314
+ ) -> Any:
315
+ """Prefer tool_choice=required so the model cannot end a turn without a tool call."""
316
+ kwargs: dict[str, Any] = {
317
+ "model": MODEL_NAME,
318
+ "messages": messages,
319
+ "tools": tools,
320
+ "parallel_tool_calls": False,
321
+ "temperature": 0,
322
+ "top_p": 1.0,
323
+ }
324
+ llm_seed = _llm_seed(env_seed, task_id)
325
+ if llm_seed is not None:
326
+ kwargs["seed"] = llm_seed
327
+
328
+ def _call(tool_choice: str) -> Any:
329
+ return client.chat.completions.create(**kwargs, tool_choice=tool_choice)
330
+
331
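+     try:
+         return _call("required")
+     except BadRequestError as e:
+         last_exc = e
+         err = str(e).lower()
+         if "seed" in err and llm_seed is not None:
+             # Some OpenAI-compatible backends reject `seed`; retry without it.
+             kwargs.pop("seed", None)
+             try:
+                 return _call("required")
+             except BadRequestError as e2:
+                 last_exc = e2
+                 err = str(e2).lower()
+         if not any(x in err for x in ("tool_choice", "required", "unsupported")):
+             # Re-raise the most recent failure rather than the original seed error.
+             raise last_exc
+         # The backend does not accept tool_choice="required"; fall back to "auto".
+         return _call("auto")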
344
+
345
+
346
+ def run_task(
347
+ client: OpenAI,
348
+ http: requests.Session,
349
+ task_id: str,
350
+ *,
351
+ max_turns: int,
352
+ seed: int | None,
353
+ ) -> float:
354
+ rewards: list[float] = []
355
+ steps_taken = 0
356
+ score = 0.0
357
+ success = False
358
+
359
+ log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
360
+
361
+ try:
362
+ tools = _build_tools(task_id)
363
+ names_csv = _allowed_tool_names_csv(task_id)
364
+ reset_resp = _request_json(
365
+ http,
366
+ "POST",
367
+ "/reset",
368
+ timeout=10,
369
+ params={"task_id": task_id},
370
+ json={} if seed is None else {"seed": seed},
371
+ )
372
+ reset_obs = reset_resp.get("observation", reset_resp)
373
+
374
+ messages: list[ChatCompletionMessageParam] = [
375
+ {"role": "system", "content": _system_prompt_for_task(task_id)},
376
+ {
377
+ "role": "user",
378
+ "content": TASK_PROMPTS[task_id]
379
+ + f"\n\nEnvironment says: {reset_obs['message']}",
380
+ },
381
+ ]
382
+
383
+ done = False
384
+ step_num = 0
385
+ no_tool_streak = 0
386
+ for turn in range(1, max_turns + 1):
387
+ try:
388
+ response = _create_chat_completion(
389
+ client,
390
+ messages,
391
+ tools,
392
+ task_id=task_id,
393
+ env_seed=seed,
394
+ )
395
+ except BadRequestError as e:
396
+ err_str = str(e).lower()
397
+ if "tool" not in err_str and "function" not in err_str:
398
+ raise
399
+ if messages and messages[-1].get("role") == "assistant": # type: ignore[union-attr]
400
+ messages.pop()
401
+ messages.append(
402
+ {
403
+ "role": "user",
404
+ "content": (
405
+ "IMPORTANT: Call tools using ONLY these exact names: "
406
+ f"{names_csv}. "
407
+ "Put ALL parameters inside the tool's JSON arguments field. "
408
+ "Do NOT embed parameters in the tool name itself."
409
+ ),
410
+ }
411
+ )
412
+ try:
413
+ response = _create_chat_completion(
414
+ client,
415
+ messages,
416
+ tools,
417
+ task_id=task_id,
418
+ env_seed=seed,
419
+ )
420
+ except BadRequestError:
421
+ break
422
+ msg = response.choices[0].message
423
+
424
+ if not msg.tool_calls:
425
+ no_tool_streak += 1
426
+ if no_tool_streak > 3:
427
+ break
428
+ messages.append(msg) # type: ignore[arg-type]
429
+ messages.append(
430
+ {
431
+ "role": "user",
432
+ "content": (
433
+ f"You must respond with exactly one tool call ({names_csv}). "
434
+ "Do not reply with plain text only."
435
+ ),
436
+ }
437
+ )
438
+ continue
439
+ no_tool_streak = 0
440
+
441
+ messages.append(msg) # type: ignore[arg-type]
442
+
443
+ for tc in msg.tool_calls:
444
+ try:
445
+ action_payload = _action_from_tool_call(tc)
446
+ except (json.JSONDecodeError, ValidationError, ValueError) as e:
447
+ messages.append(
448
+ {
449
+ "role": "tool",
450
+ "tool_call_id": tc.id,
451
+ "content": f"Invalid tool arguments: {e}",
452
+ }
453
+ )
454
+ continue
455
+
456
+ step_num += 1
457
+ step_resp = _request_json(
458
+ http,
459
+ "POST",
460
+ "/step",
461
+ timeout=30,
462
+ json={"action": action_payload},
463
+ )
464
+ obs = step_resp.get("observation", step_resp)
465
+ reward_raw = step_resp.get("reward")
466
+ reward = 0.0 if reward_raw is None else float(reward_raw)
467
+ done = step_resp.get("done", False)
468
+
469
+ rewards.append(reward)
470
+ steps_taken = step_num
471
+ err = _obs_error(obs if isinstance(obs, dict) else {})
472
+ log_step(
473
+ step=step_num,
474
+ action=_action_str(action_payload),
475
+ reward=reward,
476
+ done=done,
477
+ error=err,
478
+ )
479
+
480
+ messages.append(
481
+ {"role": "tool", "tool_call_id": tc.id, "content": json.dumps(obs)}
482
+ )
483
+
484
+ if done:
485
+ break
486
+
487
+ if done:
488
+ break
489
+
490
+ grade = _request_json(http, "GET", f"/grader/{task_id}", timeout=10)
491
+ _emit_grader_details_to_stderr(grade)
492
+ score = float(grade["score"])
493
+ score = min(max(score, 0.0), 1.0)
494
+ success = score >= SUCCESS_SCORE_THRESHOLD
495
+ except Exception as exc:
496
+ print(
497
+ f"[inference] task={task_id} failed: {exc!r}", file=sys.stderr, flush=True
498
+ )
499
+ finally:
500
+ log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
501
+
502
+ return score
503
+
504
+
505
+ def _parse_args() -> argparse.Namespace:
506
+ p = argparse.ArgumentParser(
507
+ description="DataOpsEnv inference (OpenAI client; protocol lines on stdout)."
508
+ )
509
+ p.add_argument(
510
+ "--task",
511
+ action="append",
512
+ choices=TASK_IDS,
513
+ dest="tasks",
514
+ help="Run only the selected task(s). Defaults to all tasks.",
515
+ )
516
+ p.add_argument(
517
+ "--seed",
518
+ type=int,
519
+ default=None,
520
+ help="Environment seed for /reset; also used for LLM seed when the API supports it.",
521
+ )
522
+ p.add_argument(
523
+ "--max-turns",
524
+ type=int,
525
+ default=MAX_TURNS,
526
+ help=f"Maximum tool-using turns per task (default: {MAX_TURNS}).",
527
+ )
528
+ p.add_argument(
529
+ "--json-scores",
530
+ action="store_true",
531
+ help="Print a final JSON object with scores to stdout (for POST /baseline).",
532
+ )
533
+ return p.parse_args()
534
+
535
+
536
+ def _run_inference_sync(args: argparse.Namespace) -> None:
537
+ client = _openai_client()
538
+ scores: dict[str, float] = {}
539
+ grades: dict[str, dict[str, Any]] = {}
540
+ task_ids = args.tasks or list(TASK_PROMPTS)
541
+ with requests.Session() as http:
542
+ for task_id in task_ids:
543
+ scores[task_id] = run_task(
544
+ client,
545
+ http,
546
+ task_id,
547
+ max_turns=max(1, int(args.max_turns)),
548
+ seed=args.seed,
549
+ )
550
+ if args.json_scores:
551
+ try:
552
+ grades[task_id] = _request_json(
553
+ http,
554
+ "GET",
555
+ f"/grader/{task_id}",
556
+ timeout=10,
557
+ )
558
+ except Exception:
559
+ grades[task_id] = {
560
+ "task_id": task_id,
561
+ "score": scores[task_id],
562
+ }
563
+
564
+ if args.json_scores:
565
+ avg = sum(scores.values()) / len(scores)
566
+ payload = {
567
+ "scores": scores,
568
+ "grades": grades,
569
+ "average": round(avg, 4),
570
+ "model": MODEL_NAME,
571
+ "metadata": {
572
+ "env_base_url": ENV_BASE_URL,
573
+ "seed": args.seed,
574
+ "max_turns": max(1, int(args.max_turns)),
575
+ "tasks": task_ids,
576
+ "generated_at_utc": datetime.now(timezone.utc).isoformat(),
577
+ "model_base_url": str(getattr(client, "base_url", "")),
578
+ },
579
+ }
580
+ print(json.dumps(payload), flush=True)
581
+
582
+
583
+ async def main() -> None:
584
+ args = _parse_args()
585
+ await asyncio.to_thread(_run_inference_sync, args)
586
+
587
+
588
+ if __name__ == "__main__":
589
+ asyncio.run(main())
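The `[STEP]` lines printed by `log_step` are meant to be machine-readable, so a consumer can recover each field with a single regular expression. The following is a minimal sketch of such a consumer, not part of the commit: `StepRecord` and `parse_step_line` are hypothetical names, and the pattern assumes the exact f-string layout shown in `log_step` (two-decimal reward, lowercase `true`/`false`, `null` for a missing error).

```python
import re
from typing import Optional, TypedDict


class StepRecord(TypedDict):
    step: int
    action: str
    reward: float
    done: bool
    error: Optional[str]


# Mirrors the f-string in log_step. `action` may contain spaces, so it is
# matched greedily and bounded by the fixed-format fields that follow it.
_STEP_RE = re.compile(
    r"^\[STEP\] step=(\d+) action=(.*) reward=(-?\d+\.\d{2}) "
    r"done=(true|false) error=(.*)$",
    re.DOTALL,
)


def parse_step_line(line: str) -> Optional[StepRecord]:
    """Parse one [STEP] protocol line; return None for any other line."""
    m = _STEP_RE.match(line.rstrip("\n"))
    if m is None:
        return None
    step, action, reward, done, error = m.groups()
    return {
        "step": int(step),
        "action": action,
        "reward": float(reward),
        "done": done == "true",
        "error": None if error == "null" else error,
    }
```

Because the action text is matched greedily, an error message that itself contains a ` reward=… done=… error=` sequence would shift the split point; a stricter consumer might prefer a structured format such as the `--json-scores` payload.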
models.py ADDED
@@ -0,0 +1,113 @@
1
+ import re
2
+ from typing import Any, Dict, List, Literal, Optional
3
+
4
+ from openenv.core.env_server import (
5
+ Action as BaseAction,
6
+ )
7
+ from openenv.core.env_server import (
8
+ Observation as BaseObservation,
9
+ )
10
+ from openenv.core.env_server import (
11
+ State as BaseState,
12
+ )
13
+ from pydantic import BaseModel, Field, field_validator
14
+
15
+ # ── Action Payload Models (Pydantic-validated) ─────────────────────
16
+
17
+
18
+ class ExecuteSQLPayload(BaseModel):
19
+ query: str = Field(..., min_length=1, max_length=2000)
20
+
21
+
22
+ class ReadFilePayload(BaseModel):
23
+ filepath: str = Field(..., min_length=1, max_length=255)
24
+
25
+
26
+ class WriteFilePayload(BaseModel):
27
+ filepath: str = Field(..., min_length=1, max_length=255)
28
+ content: str = Field(..., max_length=1_000_000)
29
+
30
+
31
+ class RunScriptPayload(BaseModel):
32
+ filepath: str = Field(..., min_length=1, max_length=255)
33
+ args: List[str] = Field(default_factory=list, max_length=20)
34
+
35
+     @field_validator("filepath")
+     @classmethod
+     def must_be_safe_script_name(cls, v: str) -> str:
+         basename = v.rsplit("/", 1)[-1]
+         # fullmatch: with re.match, `$` would also accept a trailing newline.
+         if not re.fullmatch(r"[a-zA-Z0-9_\-]+\.py", basename):
+             raise ValueError(
+                 "Script name may contain only letters, digits, '_' and '-', and must end in .py."
+             )
+         return v
+
+     @field_validator("args")
+     @classmethod
+     def args_must_be_safe(cls, v: list[str]) -> list[str]:
+         for arg in v:
+             if not isinstance(arg, str) or len(arg) > 500:
+                 raise ValueError("Each arg must be a string of at most 500 chars.")
+         return v
50
+
51
+
52
+ class SendEmailPayload(BaseModel):
53
+ to_email: str = Field(..., max_length=320)
54
+ subject: str = Field(..., min_length=1, max_length=500)
55
+ body: str = Field(..., min_length=1, max_length=100_000)
56
+
57
+     @field_validator("to_email")
+     @classmethod
+     def must_look_like_email(cls, v: str) -> str:
+         # fullmatch rejects a trailing newline that `$` would let through.
+         if not re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", v):
+             raise ValueError("Invalid email format.")
+         return v
63
+
64
+
65
+ ACTION_TYPE = Literal["ExecuteSQL", "ReadFile", "WriteFile", "RunScript", "SendEmail"]
66
+
67
+ PAYLOAD_MODELS: dict[str, type[BaseModel]] = {
68
+ "ExecuteSQL": ExecuteSQLPayload,
69
+ "ReadFile": ReadFilePayload,
70
+ "WriteFile": WriteFilePayload,
71
+ "RunScript": RunScriptPayload,
72
+ "SendEmail": SendEmailPayload,
73
+ }
74
+
75
+
76
+ # ── Action Model (extends OpenEnv Action) ──────────────────────────
77
+
78
+
79
+ class DataOpsAction(BaseAction):
80
+ action_type: ACTION_TYPE = Field(
81
+ ..., description="One of: ExecuteSQL, ReadFile, WriteFile, RunScript, SendEmail"
82
+ )
83
+ payload: Dict[str, Any] = Field(
84
+ ..., description="Parameters for the chosen action type."
85
+ )
86
+
87
+
88
+ # ── Observation Model (extends OpenEnv Observation) ────────────────
89
+
90
+
91
+ class DataOpsObservation(BaseObservation):
92
+ status: Literal["success", "error"] = "error"
93
+ message: str = ""
94
+ stdout: Optional[str] = None
95
+ stderr: Optional[str] = None
96
+ sql_results: Optional[List[Dict[str, Any]]] = None
97
+ email_delivery_status: Optional[str] = None
98
+ step_count: int = 0
99
+ max_steps: int = 0
100
+
101
+
102
+ # ── State Model (extends OpenEnv State) ────────────────────────────
103
+
104
+
105
+ class DataOpsState(BaseState):
106
+ task_id: str = ""
107
+ task_description: str = ""
108
+ seed: int = 0
109
+ max_steps: int = 15
110
+ done: bool = False
111
+ cumulative_reward: float = 0.0
112
+ actions_taken: List[str] = Field(default_factory=list)
113
+ emails_sent: int = 0
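The `filepath` validator on `RunScriptPayload` inspects only the last path component. A stdlib-only sketch of that check follows; `is_safe_script_path` is a hypothetical helper name used for illustration, and it relies on `re.fullmatch` so that a trailing newline is rejected as well.

```python
import re

# Same character set as the RunScriptPayload validator: letters, digits,
# underscore, and hyphen, followed by a literal ".py".
_SCRIPT_NAME = re.compile(r"[a-zA-Z0-9_\-]+\.py")


def is_safe_script_path(filepath: str) -> bool:
    """Return True when the basename looks like a plain .py script name."""
    basename = filepath.rsplit("/", 1)[-1]
    return _SCRIPT_NAME.fullmatch(basename) is not None
```

Note that the directory part is not examined here, so a path such as `dir with space/ok.py` passes. Any containment of paths to the workspace would have to happen elsewhere; whether the environment does so is not visible in this part of the diff.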
openenv.yaml ADDED
@@ -0,0 +1,36 @@
1
+ spec_version: 1
2
+ name: dataops_env
3
+ version: 1.0.0
4
+ description: Seeded enterprise DataOps benchmark with isolated sessions, deterministic graders, and three escalating remediation tasks.
5
+ type: space
6
+ runtime: fastapi
7
+ app: server.app:app
8
+ port: 7860
9
+ tasks:
10
+ - id: task_1_easy_anomaly
11
+ name: Delete Corrupted Transaction Rows
12
+ difficulty: easy
13
+ description: Inspect a seeded transaction table and remove only the rows whose amount is NULL, preserving legitimate non-null edge values.
14
+ benchmark_focus: Careful data cleanup without collateral damage.
15
+ allowed_actions:
16
+ - ExecuteSQL
17
+ - id: task_2_medium_syntax
18
+ name: Repair Seeded Pipeline Script
19
+ difficulty: medium
20
+ description: Repair a seeded ETL normalization script and verify it against visible and hidden seeded batches.
21
+ benchmark_focus: Code reading, precise repair, and generalization beyond the demo batch.
22
+ allowed_actions:
23
+ - ReadFile
24
+ - WriteFile
25
+ - RunScript
26
+ - id: task_3_hard_e2e
27
+ name: Resolve Revenue Reporting Incident
28
+ difficulty: hard
29
+ description: Extract a seeded reporting slice, repair the formatter, and send the exact generated report.
30
+ benchmark_focus: End-to-end data extraction, file repair, and communication with provenance.
31
+ allowed_actions:
32
+ - ExecuteSQL
33
+ - ReadFile
34
+ - WriteFile
35
+ - RunScript
36
+ - SendEmail
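The `allowed_actions` lists in this manifest determine which tools the inference runner exposes per task. The sketch below reproduces that mapping with plain dictionaries so it runs without PyYAML. `ALLOWED_ACTIONS` is a hand-copied stand-in for the parsed manifest, and `allowed_tool_names` is a hypothetical name; the action-to-tool table is taken from `_ACTION_TO_TOOL` in `inference.py`.

```python
# Action type (manifest) -> tool name (as exposed to the model).
ACTION_TO_TOOL = {
    "ExecuteSQL": "execute_sql",
    "ReadFile": "read_file",
    "WriteFile": "write_file",
    "RunScript": "invoke_python",
    "SendEmail": "send_email",
}

# Fixed presentation order, independent of the order in the manifest.
TOOL_ORDER = ("execute_sql", "read_file", "write_file", "invoke_python", "send_email")

# Hand-copied from the tasks section of openenv.yaml (illustrative stand-in).
ALLOWED_ACTIONS = {
    "task_1_easy_anomaly": ["ExecuteSQL"],
    "task_2_medium_syntax": ["ReadFile", "WriteFile", "RunScript"],
    "task_3_hard_e2e": ["ExecuteSQL", "ReadFile", "WriteFile", "RunScript", "SendEmail"],
}


def allowed_tool_names(task_id: str) -> list[str]:
    """Return the tool names a task may use, in the fixed presentation order."""
    allowed = {ACTION_TO_TOOL[a] for a in ALLOWED_ACTIONS[task_id]}
    return [t for t in TOOL_ORDER if t in allowed]
```

For example, `allowed_tool_names("task_2_medium_syntax")` yields the three file and script tools and omits `execute_sql` and `send_email`.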
pyproject.toml ADDED
@@ -0,0 +1,42 @@
1
+ [build-system]
2
+ requires = ["setuptools>=45", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "openenv-dataops_env"
7
+ version = "1.0.0"
8
+ description = "Enterprise data pipeline remediation environment for training AI agents (OpenEnv-compliant)."
9
+ readme = "README.md"
10
+ requires-python = ">=3.12"
11
+ dependencies = [
12
+ "openenv-core[core]>=0.2.2",
13
+ "fastapi>=0.115.0",
14
+ "starlette>=0.46.0,<0.52.0",
15
+ "uvicorn[standard]>=0.34.0",
16
+ "pydantic>=2.10.0",
17
+ "pyyaml>=6.0.2",
18
+ "openai>=1.60.0",
19
+ "requests>=2.32.0",
20
+ "wsproto>=1.3.2",
21
+ "python-dotenv>=1.0.0",
22
+ "huggingface-hub>=1.8.0",
23
+ ]
24
+
25
+ [project.optional-dependencies]
26
+ dev = [
27
+ "pytest>=8.0.0",
28
+ "pytest-cov>=4.0.0",
29
+ ]
30
+
31
+ [project.scripts]
32
+ server = "dataops_env.server.app:main"
33
+
34
+ [tool.setuptools]
35
+ include-package-data = true
36
+ packages = ["dataops_env", "dataops_env.server"]
37
+ package-dir = { "dataops_env" = ".", "dataops_env.server" = "server" }
38
+
39
+ [tool.pytest.ini_options]
40
+ pythonpath = ["."]
41
+ testpaths = ["tests"]
42
+ addopts = "--import-mode=importlib"
server/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ """DataOps environment server components."""
2
+
3
+ from .dataops_env_environment import DataOpsEnvironment
4
+
5
+ __all__ = ["DataOpsEnvironment"]
server/__main__.py ADDED
@@ -0,0 +1,6 @@
1
+ """Run the API server from the repo root: ``uv run python -m server``."""
2
+
3
+ from server.app import main
4
+
5
+ if __name__ == "__main__":
6
+ main()
server/app.py ADDED
@@ -0,0 +1,530 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import os
5
+ import subprocess
6
+ import sys
7
+ from collections.abc import AsyncIterator
8
+ from contextlib import asynccontextmanager
9
+ from pathlib import Path
10
+
11
+ import yaml
12
+ import uvicorn
13
+ from fastapi import (
14
+ Body,
15
+ FastAPI,
16
+ HTTPException,
17
+ Query,
18
+ Request,
19
+ Response,
20
+ WebSocket,
21
+ WebSocketDisconnect,
22
+ )
23
+ from fastapi.middleware.cors import CORSMiddleware
24
+ from openenv.core.env_server.http_server import serialize_observation
25
+ from openenv.core.env_server.types import (
26
+ HealthResponse,
27
+ HealthStatus,
28
+ ResetRequest,
29
+ ResetResponse,
30
+ StepRequest,
31
+ StepResponse,
32
+ WSCloseMessage,
33
+ WSErrorCode,
34
+ WSErrorResponse,
35
+ WSObservationResponse,
36
+ WSResetMessage,
37
+ WSStateMessage,
38
+ WSStateResponse,
39
+ WSStepMessage,
40
+ )
41
+ from pydantic import ValidationError
42
+
43
+ from env_loader import load_env
44
+ from models import DataOpsAction, DataOpsObservation, DataOpsState
45
+ from server.dataops_env_environment import DataOpsEnvironment
46
+ from server.grading import evaluate_task
47
+ from server.session_manager import EnvironmentSessionManager
48
+ from server.task_specs import TASK_IDS, task_manifest_entries
49
+
50
+ # Repo root must be on sys.path (e.g. run `uv run python -m server.app` or uvicorn from project root).
51
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
52
+ SERVER_DIR = Path(__file__).resolve().parent
53
+
54
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(name)s | %(message)s")
55
+ logger = logging.getLogger(__name__)
56
+
57
+ load_env()
58
+
59
+ SESSION_COOKIE_NAME = "dataops_session_id"
60
+ SESSION_HEADER_NAME = "X-Session-ID"
61
+ MAX_HTTP_SESSIONS = int(os.getenv("MAX_HTTP_SESSIONS", "128"))
62
+ HTTP_SESSION_TIMEOUT_S = float(os.getenv("HTTP_SESSION_TIMEOUT_S", "1200"))
63
+ MAX_WS_SESSIONS = max(1, int(os.getenv("MAX_WS_SESSIONS", "64")))
64
+ ADMIN_API_KEY = os.getenv("ADMIN_API_KEY", "").strip()
65
+ COOKIE_SECURE = os.getenv("COOKIE_SECURE", "").lower() in {"1", "true", "yes"}
66
+
67
+
68
+ def _public_grader_details_enabled() -> bool:
69
+ """Read at request time so env / tests can control visibility without stale import-time state."""
70
+ v = os.getenv("PUBLIC_GRADER_DETAILS", "").strip().lower()
71
+ return v in {"1", "true", "yes"}
72
+
73
+ _ws_active_sessions = 0
74
+ _ws_session_lock = asyncio.Lock()
75
+
76
+ session_manager = EnvironmentSessionManager(
77
+ max_sessions=MAX_HTTP_SESSIONS,
78
+ session_timeout_s=HTTP_SESSION_TIMEOUT_S,
79
+ )
80
+
81
+
82
+ @asynccontextmanager
83
+ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
84
+ logger.info("DataOpsEnv starting.")
85
+ yield
86
+ session_manager.close_all()
87
+ logger.info("DataOpsEnv shutting down.")
88
+
89
+
90
+ app = FastAPI(
91
+ title="DataOpsEnv",
92
+ description="Enterprise data pipeline remediation environment for training AI agents (OpenEnv-compliant).",
93
+ version="1.0.0",
94
+ lifespan=lifespan,
95
+ )
96
+
97
+
98
+ def _cors_allow_origins() -> list[str]:
99
+ configured = os.getenv("CORS_ALLOW_ORIGINS", "").strip()
100
+
101
+ if not configured:
102
+ return []
103
+
104
+ if configured == "*":
105
+ return ["*"]
106
+
107
+ return [item.strip() for item in configured.split(",") if item.strip()]
108
+
109
+
110
+ app.add_middleware(
111
+ CORSMiddleware,
112
+ allow_origins=_cors_allow_origins(),
113
+ allow_methods=["*"],
114
+ allow_headers=["*"],
115
+ )
116
+
117
+
118
+ def _load_manifest() -> dict:
119
+ yaml_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "openenv.yaml")
120
+ try:
121
+ with open(yaml_path, encoding="utf-8") as f:
122
+ return yaml.safe_load(f) or {}
123
+ except FileNotFoundError:
124
+ return {}
125
+
126
+
127
+ def _load_yaml_tasks() -> list[dict]:
128
+ manifest = _load_manifest()
129
+ tasks = manifest.get("tasks")
130
+ if isinstance(tasks, list) and tasks:
131
+ manifest_ids = [str(item.get("id", "")) for item in tasks]
132
+ if manifest_ids == TASK_IDS:
133
+ return tasks
134
+ return task_manifest_entries()
135
+
136
+
137
+ def _wrap_obs(obs: DataOpsObservation) -> dict:
138
+ """Serialise an observation to the standard OpenEnv response dict."""
139
+ return obs.model_dump()
140
+
141
+
142
+ def _get_session_id(request: Request) -> str | None:
143
+ header_value = request.headers.get(SESSION_HEADER_NAME)
144
+ if header_value:
145
+ return header_value.strip() or None
146
+ cookie_value = request.cookies.get(SESSION_COOKIE_NAME)
147
+ if cookie_value:
148
+ return cookie_value.strip() or None
149
+ return None
150
+
151
+
152
+ def _attach_session(response: Response, session_id: str) -> None:
153
+ response.set_cookie(
154
+ key=SESSION_COOKIE_NAME,
155
+ value=session_id,
156
+ httponly=True,
157
+ samesite="lax",
158
+ secure=COOKIE_SECURE,
159
+ max_age=int(HTTP_SESSION_TIMEOUT_S),
160
+ )
161
+ response.headers[SESSION_HEADER_NAME] = session_id
162
+
163
+
164
+ def _require_active_env(request: Request) -> tuple[str, DataOpsEnvironment]:
165
+ session_id, env = session_manager.get_session(_get_session_id(request))
166
+ if session_id is None or env is None:
167
+ raise HTTPException(400, "No active episode. Call /reset first.")
168
+ return session_id, env
169
+
170
+
171
+ def _ws_error_payload(message: str, code: WSErrorCode) -> str:
172
+ return WSErrorResponse(
173
+ data={
174
+ "message": message,
175
+ "code": code.value,
176
+ }
177
+ ).model_dump_json()
178
+
179
+
180
+ def _require_admin(request: Request) -> None:
181
+ if not ADMIN_API_KEY:
182
+ return
183
+ if request.headers.get("X-Admin-Key", "") != ADMIN_API_KEY:
184
+ raise HTTPException(403, "Missing or invalid admin key.")
185
+
186
+
187
+ def _request_is_admin(request: Request) -> bool:
188
+ return bool(ADMIN_API_KEY) and request.headers.get("X-Admin-Key", "") == ADMIN_API_KEY
189
+
190
+
191
+ def _format_grader_response(grade: dict, request: Request) -> dict:
192
+ if _public_grader_details_enabled() or _request_is_admin(request):
193
+ return grade
194
+ return {"task_id": grade.get("task_id"), "score": grade.get("score")}
195
+
196
+
197
+ async def _try_acquire_ws_slot() -> bool:
198
+ global _ws_active_sessions
199
+
200
+ async with _ws_session_lock:
201
+ if _ws_active_sessions >= MAX_WS_SESSIONS:
202
+ return False
203
+ _ws_active_sessions += 1
204
+ return True
205
+
206
+
207
+ async def _release_ws_slot() -> None:
208
+ global _ws_active_sessions
209
+
210
+ async with _ws_session_lock:
211
+ _ws_active_sessions = max(0, _ws_active_sessions - 1)
212
+
213
+
214
+ @app.get("/health", response_model=HealthResponse)
215
+ def health_endpoint():
216
+ return HealthResponse(status=HealthStatus.HEALTHY)
217
+
218
+
219
+ @app.get("/metadata")
220
+ def metadata_endpoint():
221
+ manifest = _load_manifest()
222
+ return {
223
+ "name": manifest.get("name", "dataops_env"),
224
+ "description": manifest.get(
225
+ "description",
226
+ (
227
+ "Enterprise data pipeline remediation environment. "
228
+ "Agents debug data streams, fix scripts, and send email reports."
229
+ ),
230
+ ),
231
+ "version": manifest.get("version", "1.0.0"),
232
+ "task_count": len(_load_yaml_tasks()),
233
+ }
234
+
235
+
236
+ @app.get("/schema")
237
+ def schema_endpoint():
238
+ return {
239
+ "action": DataOpsAction.model_json_schema(),
240
+ "observation": DataOpsObservation.model_json_schema(),
241
+ "state": DataOpsState.model_json_schema(),
242
+ }
243
+
244
+
245
+ @app.post("/mcp")
246
+ def mcp_endpoint(body: dict = Body(default_factory=dict)):
247
+ method = body.get("method", "")
248
+ req_id = body.get("id")
249
+
250
+ if method == "tools/list":
251
+ tools = [
252
+ {"name": atype, "description": f"Execute a {atype} action."}
253
+ for atype in [
254
+ "ExecuteSQL",
255
+ "ReadFile",
256
+ "WriteFile",
257
+ "RunScript",
258
+ "SendEmail",
259
+ ]
260
+ ]
261
+ return {"jsonrpc": "2.0", "id": req_id, "result": {"tools": tools}}
262
+
263
+ return {
264
+ "jsonrpc": "2.0",
265
+ "id": req_id,
266
+ "error": {"code": -32601, "message": "Method not found"},
267
+ }
268
+
269
+
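The `/mcp` handler above speaks a small subset of JSON-RPC 2.0: `tools/list` returns a `result`, and any other method gets error code `-32601`. A framework-free sketch of the same dispatch is shown below. The function name `handle_rpc` and the module-level `ACTION_TYPES` list are assumptions made for illustration; the five action names are the ones listed in the handler.

```python
from typing import Any

ACTION_TYPES = ["ExecuteSQL", "ReadFile", "WriteFile", "RunScript", "SendEmail"]


def handle_rpc(body: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical framework-free version of the /mcp dispatch."""
    req_id = body.get("id")
    if body.get("method", "") == "tools/list":
        tools = [
            {"name": name, "description": f"Execute a {name} action."}
            for name in ACTION_TYPES
        ]
        return {"jsonrpc": "2.0", "id": req_id, "result": {"tools": tools}}
    # -32601 is the JSON-RPC 2.0 code for "Method not found".
    return {
        "jsonrpc": "2.0",
        "id": req_id,
        "error": {"code": -32601, "message": "Method not found"},
    }
```

Echoing the request `id` in both branches lets a client match responses to requests even when the call fails.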
270
+ @app.websocket("/ws")
271
+ async def websocket_endpoint(websocket: WebSocket):
272
+ await websocket.accept()
273
+ acquired_slot = await _try_acquire_ws_slot()
274
+ if not acquired_slot:
275
+ await websocket.send_text(
276
+ _ws_error_payload(
277
+ "WebSocket session capacity reached.",
278
+ WSErrorCode.CAPACITY_REACHED,
279
+ )
280
+ )
281
+ await websocket.close(code=1013)
282
+ return
283
+
284
+ env = DataOpsEnvironment()
285
+
286
+ try:
287
+ while True:
288
+ raw_message = await websocket.receive_text()
289
+
290
+ try:
291
+ message_dict = json.loads(raw_message)
292
+ except json.JSONDecodeError:
293
+ await websocket.send_text(
294
+ _ws_error_payload("Invalid JSON payload.", WSErrorCode.INVALID_JSON)
295
+ )
296
+ continue
297
+
298
+ message_type = message_dict.get("type", "")
299
+
300
+ try:
301
+ if message_type == "reset":
302
+ message = WSResetMessage(**message_dict)
303
+ observation = env.reset(**message.data)
304
+ response = WSObservationResponse(
305
+ data=serialize_observation(observation)
306
+ )
307
+ elif message_type == "step":
308
+ message = WSStepMessage(**message_dict)
309
+ action = DataOpsAction(**message.data)
310
+ observation = env.step(action)
311
+ response = WSObservationResponse(
312
+ data=serialize_observation(observation)
313
+ )
314
+ elif message_type == "state":
315
+ WSStateMessage(**message_dict)
316
+ response = WSStateResponse(data=env.state.model_dump())
317
+ elif message_type == "close":
318
+ WSCloseMessage(**message_dict)
319
+ break
320
+ else:
321
+ await websocket.send_text(
322
+ _ws_error_payload(
323
+ f"Unknown message type: {message_type}",
324
+ WSErrorCode.UNKNOWN_TYPE,
325
+ )
326
+ )
327
+ continue
328
+
329
+ await websocket.send_text(response.model_dump_json())
330
+ except ValidationError:
331
+ await websocket.send_text(
332
+ _ws_error_payload(
333
+ "Validation error while handling the WebSocket message.",
334
+ WSErrorCode.VALIDATION_ERROR,
335
+ )
336
+ )
337
+ except Exception:
338
+ logger.exception("WebSocket execution error")
339
+ await websocket.send_text(
340
+ _ws_error_payload(
341
+ "Execution error while handling the WebSocket message.",
342
+ WSErrorCode.EXECUTION_ERROR,
343
+ )
344
+ )
345
+ except WebSocketDisconnect:
346
+ logger.debug("WebSocket client disconnected.")
347
+ finally:
348
+ env.close()
349
+ await _release_ws_slot()
350
+
351
+
352
+ @app.post("/reset", response_model=ResetResponse)
353
+ def reset_endpoint(
354
+ request: Request,
355
+ response: Response,
356
+ task_id: str = Query("task_1_easy_anomaly", description="Task to initialise."),
357
+ body: ResetRequest = Body(default_factory=ResetRequest),
358
+ ):
359
+ if task_id not in TASK_IDS:
360
+ raise HTTPException(400, f"Invalid task_id. Choose from: {TASK_IDS}")
361
+ session_id = _get_session_id(request)
362
+ resolved_session_id, _env, obs = session_manager.reset_session(
363
+ task_id=task_id,
364
+ seed=body.seed,
365
+ episode_id=body.episode_id,
366
+ session_id=session_id,
367
+ )
368
+ _attach_session(response, resolved_session_id)
369
+ return ResetResponse(observation=_wrap_obs(obs), reward=obs.reward, done=obs.done)
370
+
371
+
372
+ @app.post("/step", response_model=StepResponse)
373
+ def step_endpoint(request: Request, response: Response, body: StepRequest):
374
+ try:
375
+ action = DataOpsAction(**body.action)
376
+ except ValidationError as e:
377
+ raise HTTPException(422, f"Invalid action: {e}") from e
378
+ session_id, env = _require_active_env(request)
379
+ _attach_session(response, session_id)
380
+ obs = env.step(action, timeout_s=body.timeout_s)
381
+ return StepResponse(observation=_wrap_obs(obs), reward=obs.reward, done=obs.done)
382
+
383
+
384
+ @app.get("/state", response_model=DataOpsState)
385
+ def state_endpoint(request: Request, response: Response):
386
+ session_id, env = _require_active_env(request)
387
+ _attach_session(response, session_id)
388
+ return env.state
389
+
390
+
391
+ @app.get("/tasks")
392
+ def tasks_endpoint():
393
+ return {
394
+ "tasks": _load_yaml_tasks(),
395
+ "action_schema": DataOpsAction.model_json_schema(),
396
+ "observation_schema": DataOpsObservation.model_json_schema(),
397
+ "state_schema": DataOpsState.model_json_schema(),
398
+ }
399
+
400
+
401
+ @app.get("/grader")
402
+ def grader_current_endpoint(request: Request, response: Response):
403
+ """Grade the current episode (uses active task_id from state)."""
404
+ session_id, env = _require_active_env(request)
405
+ _attach_session(response, session_id)
406
+ task_id = env.state.task_id
407
+ if not task_id:
408
+ raise HTTPException(400, "No active episode. Call /reset first.")
409
+ return _format_grader_response(evaluate_task(task_id, env), request)
410
+
411
+
412
+ @app.get("/grader/{task_id}")
413
+ def grader_endpoint(task_id: str, request: Request, response: Response):
414
+ if task_id not in TASK_IDS:
415
+ raise HTTPException(404, f"Unknown task: {task_id}")
416
+ session_id, env = _require_active_env(request)
417
+ _attach_session(response, session_id)
418
+ active_task_id = env.state.task_id
419
+ if active_task_id and active_task_id != task_id:
420
+ raise HTTPException(
421
+ 400,
422
+ f"Active episode belongs to task '{active_task_id}'. Reset the requested task first.",
423
+ )
424
+ return _format_grader_response(evaluate_task(task_id, env), request)
425
+
426
+
427
+ @app.post("/baseline")
428
+ def baseline_endpoint(request: Request, body: dict = Body(default_factory=dict)):
429
+ """Run inference.py (OpenAI tool-calling agent) against all tasks; same entrypoint as local baseline."""
430
+ _require_admin(request)
431
+ if not (
432
+ os.environ.get("API_KEY", "").strip()
433
+ or os.environ.get("HF_TOKEN", "").strip()
434
+ ):
435
+ raise HTTPException(
436
+ 503,
437
+ "API_KEY or HF_TOKEN must be set on the server process to run POST /baseline.",
438
+ )
439
+ script_path = PROJECT_ROOT / "inference.py"
440
+ if not script_path.is_file():
441
+ raise HTTPException(500, "inference.py missing from project root.")
442
+
443
+ port = int(os.getenv("PORT", "7860"))
444
+ timeout_s = HTTP_SESSION_TIMEOUT_S
445
+
446
+ env = {
447
+ **os.environ,
448
+ "ENV_BASE_URL": os.getenv("ENV_BASE_URL", f"http://127.0.0.1:{port}"),
449
+ }
450
+ command = [sys.executable, str(script_path), "--json-scores"]
451
+ if body.get("seed") is not None:
452
+ command.extend(["--seed", str(int(body["seed"]))])
453
+ if body.get("max_turns") is not None:
454
+ command.extend(["--max-turns", str(int(body["max_turns"]))])
455
+ for task_id in body.get("task_ids", []) or []:
456
+ if task_id in TASK_IDS:
457
+ command.extend(["--task", str(task_id)])
458
+
459
+ try:
460
+ proc = subprocess.run(
461
+ command,
462
+ cwd=str(PROJECT_ROOT),
463
+ capture_output=True,
464
+ text=True,
465
+ timeout=timeout_s,
466
+ env=env,
467
+ )
468
+ except subprocess.TimeoutExpired:
469
+ raise HTTPException(
470
+ 504, f"Baseline exceeded HTTP_SESSION_TIMEOUT_S ({timeout_s}s)."
471
+ ) from None
472
+
473
+ if proc.returncode != 0:
474
+ tail = (proc.stderr or proc.stdout or "")[-6000:]
475
+ logger.error("inference.py failed rc=%s output_tail=%s", proc.returncode, tail[-500:])
476
+ raise HTTPException(
477
+ 502,
478
+ {"message": "inference.py exited with an error.", "detail": tail},
479
+ )
480
+
481
+ lines = [ln.strip() for ln in (proc.stdout or "").splitlines() if ln.strip()]
482
+ parsed = None
483
+ for line in reversed(lines):
484
+ try:
485
+ parsed = json.loads(line)
486
+ break
487
+ except json.JSONDecodeError:
488
+ continue
489
+ if not isinstance(parsed, dict) or "scores" not in parsed:
490
+ raise HTTPException(
491
+ 502,
492
+ {
493
+ "message": "Could not parse JSON scores from inference.py stdout.",
494
+ "stdout_tail": "\n".join(lines[-5:]),
495
+ },
496
+ )
497
+
498
+ return {
499
+ "message": "Model baseline completed via inference.py.",
500
+ "stdout": proc.stdout,
501
+ "stderr": proc.stderr,
502
+ "scores": parsed["scores"],
503
+ "grades": parsed.get("grades"),
504
+ "average": parsed.get("average"),
505
+ "model": parsed.get("model"),
506
+ "metadata": parsed.get("metadata"),
507
+ }
508
+
509
+
510
+ def main():
511
+ """Entry point for `dataops-env` script and `openenv serve`."""
512
+ host = os.getenv("HOST", "0.0.0.0")
513
+ port = int(os.getenv("PORT", "7860"))
514
+ reload = os.getenv("DEBUG", "").lower() in ("1", "true")
515
+ cwd = Path.cwd().resolve()
516
+ app_target = "app:app" if cwd == SERVER_DIR else "server.app:app"
517
+ app_dir = str(SERVER_DIR if app_target == "app:app" else PROJECT_ROOT)
518
+ uvicorn.run(
519
+ app_target,
520
+ host=host,
521
+ port=port,
522
+ reload=reload,
523
+ reload_dirs=[str(PROJECT_ROOT)] if reload else None,
524
+ ws="wsproto",
525
+ app_dir=app_dir,
526
+ )
527
+
528
+
529
+ if __name__ == "__main__":
530
+ main()
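The `/baseline` endpoint earlier in this file recovers its scores by scanning the subprocess output from the bottom and stopping at the first line that parses as JSON. A minimal sketch of that scan follows. The function name `last_json_object` and the sample output in the usage note are hypothetical; only the scanning logic mirrors the endpoint.

```python
import json
from typing import Any, Optional


def last_json_object(stdout: str, required_key: str = "scores") -> Optional[dict[str, Any]]:
    """Return the last parseable JSON line if it is a dict containing required_key."""
    lines = [ln.strip() for ln in stdout.splitlines() if ln.strip()]
    for line in reversed(lines):
        try:
            parsed = json.loads(line)
        except json.JSONDecodeError:
            # Plain log lines are skipped.
            continue
        # Like the endpoint, stop at the first line that parses at all.
        if isinstance(parsed, dict) and required_key in parsed:
            return parsed
        return None
    return None
```

With made-up output such as `"log line\n{\"scores\": {}}\nDone.\n"`, the trailing `Done.` is skipped and the JSON line is returned. One consequence of stopping at the first parseable line is that a trailing bare number or string, which is also valid JSON, hides an earlier scores object.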
server/dataops_env_environment.py ADDED
@@ -0,0 +1,839 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+ import re
7
+ import shutil
8
+ import sqlite3
9
+ import textwrap
10
+ import threading
11
+ import time
12
+ import uuid
13
+ from copy import deepcopy
14
+ from typing import Any, Optional
15
+
16
+ from openenv.core.env_server import Environment
17
+ from pydantic import ValidationError
18
+
19
+ from data.init_db import WORKSPACE_ROOT, setup_workspace
20
+ from models import (
21
+ PAYLOAD_MODELS,
22
+ DataOpsAction,
23
+ DataOpsObservation,
24
+ DataOpsState,
25
+ ExecuteSQLPayload,
26
+ ReadFilePayload,
27
+ RunScriptPayload,
28
+ SendEmailPayload,
29
+ WriteFilePayload,
30
+ )
31
+
32
+ from .safe_exec import PythonRunResult, run_python_code, run_python_script
33
+ from .task_specs import (
34
+ TASK_ALLOWED_READ_FILES,
35
+ TASK_ALLOWED_RUN_FILES,
36
+ TASK_ALLOWED_WRITE_FILES,
37
+ TASK_EMAIL_ENABLED,
38
+ TASK_IDS,
39
+ TASK_SQL_POLICIES,
40
+ TaskScenarioBundle,
41
+ build_task_scenario,
42
+ normalize_task_3_rows,
43
+ report_matches_expected,
44
+ task_3_data_matches_expected,
45
+ )
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+ _SQL_COMMENT_RE = re.compile(r"(--[^\n]*|/\*.*?\*/)", re.DOTALL)
50
+ _SQL_STRING_RE = re.compile(r"'(?:''|[^'])*'|\"(?:\"\"|[^\"])*\"")
51
+ _SQL_TABLE_REF_RE = re.compile(
52
+ r"\b(?:from|join|update|into|delete\s+from)\s+([a-zA-Z_][a-zA-Z0-9_]*)",
53
+ re.IGNORECASE,
54
+ )
55
+ _SQL_CTE_NAME_RE = re.compile(
56
+ r"(?:\bwith\b|,)\s*([a-zA-Z_][a-zA-Z0-9_]*)\s+as\s*\(",
57
+ re.IGNORECASE,
58
+ )
59
+
60
+ MAX_STEPS = 15
61
+ MAX_SQL_ROWS = 500
62
+ MAX_FILE_SIZE = 1_000_000
63
+ DEFAULT_ACTION_TIMEOUT_S = 10.0
64
+ MAX_ACTION_TIMEOUT_S = 30.0
65
+ MAX_STDOUT_CHARS = 50_000
66
+ MAX_STDERR_CHARS = 10_000
67
+
68
+ PENALTY_FAILURE = -0.03
69
+ PENALTY_DESTRUCTIVE = -0.20
70
+ PENALTY_REPEAT = -0.08
71
+ PENALTY_DISALLOWED_TOOL_UNIT = -0.04
72
+
73
+ # Keep milestone bonuses small so the terminal grader remains the dominant signal.
74
+ REWARD_EVENT_VALUES = {
75
+ "t1_inspected_corruption": 0.05,
76
+ "t1_exact_cleanup": 0.04,
77
+ "t2_read_source": 0.04,
78
+ "t2_candidate_compiles": 0.02,
79
+ "t2_verified_fix": 0.03,
80
+ "t3_nonempty_select": 0.03,
81
+ "t3_matching_sql": 0.03,
82
+ "t3_read_formatter_source": 0.02,
83
+ "t3_report_data_verified": 0.03,
84
+ "t3_formatter_compiles": 0.02,
85
+ "t3_report_generated": 0.03,
86
+ "t3_email_verified": 0.02,
87
+ }
88
+ PENALTY_EVENTS = {
89
+ "destructive_sql": PENALTY_DESTRUCTIVE,
90
+ "multiple_emails": -0.08,
91
+ "t2_run_before_read": -0.05,
92
+ "t2_write_before_read": -0.05,
93
+ }
94
+
95
+
96
+ class DataOpsEnvironment(Environment[DataOpsAction, DataOpsObservation, DataOpsState]):
97
+ """Enterprise data pipeline remediation environment (OpenEnv-compliant)."""
98
+
99
+ SUPPORTS_CONCURRENT_SESSIONS = True
100
+
101
+ def __init__(self) -> None:
102
+ self._workspace_dir = os.path.join(WORKSPACE_ROOT, "sessions", uuid.uuid4().hex)
103
+ self._db_path = os.path.join(self._workspace_dir, "mock_warehouse.db")
104
+ self._state = DataOpsState()
105
+ self._scenario: TaskScenarioBundle = build_task_scenario(
106
+ "task_1_easy_anomaly", seed=0
107
+ )
108
+ self._evidence: dict[str, Any] = {}
109
+ self._pending_events: list[str] = []
110
+ self.email_outbox: list[dict[str, str]] = []
111
+ self._last_action_key: Optional[str] = None
112
+ self._milestones: set[str] = set()
113
+ self._grader_score = 0.0
114
+ self._disallowed_tool_attempts = 0
115
+ self._lock = threading.Lock()
116
+
117
+ def reset(
118
+ self,
119
+ seed: Optional[int] = None,
120
+ episode_id: Optional[str] = None,
121
+ **kwargs: Any,
122
+ ) -> DataOpsObservation:
123
+ task_id: str = kwargs.get("task_id", "task_1_easy_anomaly")
124
+ if task_id not in TASK_IDS:
125
+ raise ValueError(f"Unknown task_id: {task_id}")
126
+
127
+ with self._lock:
128
+ self._scenario = build_task_scenario(task_id, seed=seed)
129
+ self._db_path = setup_workspace(
130
+ self._workspace_dir,
131
+ scenario=self._scenario,
132
+ )
133
+ self.email_outbox.clear()
134
+ self._last_action_key = None
135
+ self._milestones.clear()
136
+ self._pending_events = []
137
+ self._disallowed_tool_attempts = 0
138
+ self._evidence = self._initial_evidence()
139
+ self._state = DataOpsState(
140
+ episode_id=episode_id or str(uuid.uuid4()),
141
+ step_count=0,
142
+ task_id=task_id,
143
+ task_description=self._scenario.description,
144
+ max_steps=MAX_STEPS,
145
+ seed=self._scenario.seed,
146
+ )
147
+ self._grader_score = self._current_task_score()
148
+
149
+ return DataOpsObservation(
150
+ status="success",
151
+ done=False,
152
+ reward=0.0,
153
+ message=f"Environment reset. Task: {self._scenario.description}",
154
+ step_count=0,
155
+ max_steps=MAX_STEPS,
156
+ )
157
+
158
+ def step(
159
+ self, action: DataOpsAction, timeout_s: Optional[float] = None, **kwargs: Any
160
+ ) -> DataOpsObservation:
161
+ del kwargs
162
+ with self._lock:
163
+ return self._step_locked(action, timeout_s)
164
+
165
+ @property
166
+ def state(self) -> DataOpsState:
167
+ return self._state.model_copy()
168
+
169
+ @property
170
+ def scenario(self) -> TaskScenarioBundle:
171
+ return self._scenario
172
+
173
+ @property
174
+ def evidence(self) -> dict[str, Any]:
175
+ return deepcopy(self._evidence)
176
+
177
+ @property
178
+ def workspace_dir(self) -> str:
179
+ return self._workspace_dir
180
+
181
+ @property
182
+ def db_path(self) -> str:
183
+ return self._db_path
184
+
185
+ def close(self) -> None:
186
+ if os.path.isdir(self._workspace_dir):
187
+ shutil.rmtree(self._workspace_dir, ignore_errors=True)
188
+
189
+ def _step_locked(
190
+ self, action: DataOpsAction, timeout_s: Optional[float]
191
+ ) -> DataOpsObservation:
192
+ if self._state.done:
193
+ return self._obs(
194
+ "error", "Episode is over. Call /reset to start a new one.", done=True
195
+ )
196
+
197
+ model_cls = PAYLOAD_MODELS.get(action.action_type)
198
+ if not model_cls:
199
+ return self._obs("error", f"Unknown action_type: {action.action_type}")
200
+
201
+ try:
202
+ payload = model_cls(**action.payload)
203
+ except ValidationError as exc:
204
+ return self._obs(
205
+ "error",
206
+ f"Invalid payload: {exc.error_count()} validation error(s).",
207
+ )
208
+
209
+ self._pending_events = []
210
+ obs = self._dispatch(action.action_type, payload, timeout_s)
211
+
212
+ reward = self._compute_reward(action, obs)
213
+ self._state.step_count += 1
214
+ self._state.cumulative_reward += reward
215
+ self._state.actions_taken.append(action.action_type)
216
+ self._state.emails_sent = len(self.email_outbox)
217
+
218
+ done = self._state.step_count >= MAX_STEPS or self._task_completed()
219
+ self._state.done = done
220
+
221
+ obs.reward = round(reward, 4)
222
+ obs.done = done
223
+ obs.step_count = self._state.step_count
224
+ obs.max_steps = MAX_STEPS
225
+ return obs
226
+
227
+ def _dispatch(
228
+ self, action_type: str, payload: Any, timeout_s: Optional[float]
229
+ ) -> DataOpsObservation:
230
+ handlers = {
231
+ "ExecuteSQL": self._handle_sql,
232
+ "ReadFile": self._handle_read,
233
+ "WriteFile": self._handle_write,
234
+ "RunScript": self._handle_run,
235
+ "SendEmail": self._handle_email,
236
+ }
237
+ return handlers[action_type](payload, timeout_s)
238
+
+     def _handle_sql(
+         self, payload: ExecuteSQLPayload, timeout_s: Optional[float]
+     ) -> DataOpsObservation:
+         query = payload.query.strip()
+         while True:
+             q = query.rstrip()
+             if not q.endswith(";"):
+                 break
+             query = q[:-1].rstrip()
+         statement_type = self._statement_type(query)
+         validation_error = self._validate_sql_action(query, statement_type)
+         if validation_error:
+             return self._obs("error", validation_error)
+
+         timeout = self._resolve_timeout(timeout_s)
+         deadline = time.monotonic() + timeout
+
+         try:
+             with sqlite3.connect(self._db_path) as conn:
+                 conn.set_progress_handler(
+                     lambda: 1 if time.monotonic() >= deadline else 0,
+                     1_000,
+                 )
+                 conn.row_factory = sqlite3.Row
+                 cursor = conn.cursor()
+                 cursor.execute(query)
+
+                 if statement_type in {"SELECT", "WITH"}:
+                     cols = [c[0] for c in cursor.description or []]
+                     rows_raw = cursor.fetchmany(MAX_SQL_ROWS + 1)
+                     if len(rows_raw) > MAX_SQL_ROWS:
+                         return self._obs(
+                             "error",
+                             f"Result exceeds {MAX_SQL_ROWS} rows. Add a LIMIT clause.",
+                         )
+                     rows = [dict(zip(cols, row)) for row in rows_raw]
+                     self._record_sql_select(query, rows)
+                     return DataOpsObservation(
+                         status="success",
+                         sql_results=rows,
+                         message=f"Query returned {len(rows)} rows.",
+                     )
+
+                 conn.commit()
+                 self._record_sql_mutation(query, cursor.rowcount)
+                 return self._obs("success", f"Rows affected: {cursor.rowcount}")
+         except sqlite3.Error as exc:
+             if "interrupted" in str(exc).lower():
+                 return self._obs(
+                     "error", f"SQL execution timed out ({timeout:.1f}s limit)."
+                 )
+             logger.warning("SQL error: %s", exc)
+             msg = "SQL execution error. Check your query syntax."
+             if self._state.task_id == "task_3_hard_e2e" and re.search(
+                 r"\bdate\b", query, re.IGNORECASE
+             ):
+                 if "report_date" not in query.lower():
+                     msg += " Hint: table `daily_reports` uses column `report_date` for the calendar date."
+             return self._obs("error", msg)
298
+
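The SQL handler above enforces its time limit with `sqlite3`'s progress handler rather than a thread or signal: SQLite calls the handler periodically, and a non-zero return value aborts the running statement. The sketch below isolates that mechanism against an in-memory database. The helper name `run_with_deadline` and the `ENDLESS` query are hypothetical and exist only for this illustration.

```python
import sqlite3
import time


def run_with_deadline(query: str, timeout_s: float) -> str:
    """Run query on an in-memory database; report 'ok' or 'aborted'."""
    deadline = time.monotonic() + timeout_s
    conn = sqlite3.connect(":memory:")
    try:
        # The handler runs every 1_000 SQLite VM instructions; returning a
        # non-zero value aborts the statement with an sqlite3 error.
        conn.set_progress_handler(
            lambda: 1 if time.monotonic() >= deadline else 0, 1_000
        )
        conn.execute(query).fetchall()
        return "ok"
    except sqlite3.Error:
        return "aborted"
    finally:
        conn.close()


# A recursive CTE that never terminates on its own.
ENDLESS = (
    "WITH RECURSIVE c(x) AS (SELECT 1 UNION ALL SELECT x + 1 FROM c) "
    "SELECT count(*) FROM c"
)
```

Calling `run_with_deadline(ENDLESS, 0.2)` should come back as `"aborted"` shortly after the deadline, while a trivial `SELECT 1` completes normally. Because the check only fires between VM instructions, the abort is approximate rather than exact to the millisecond.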
299
+ def _handle_read(
300
+ self, payload: ReadFilePayload, timeout_s: Optional[float]
301
+ ) -> DataOpsObservation:
302
+ del timeout_s
303
+ basename = os.path.basename(payload.filepath)
304
+ if not self._is_allowed_file(TASK_ALLOWED_READ_FILES, basename):
305
+ return self._obs(
306
+ "error", f"Reading {basename} is not allowed for this task."
307
+ )
308
+
309
+ safe_path = self._resolve_workspace_path(basename)
310
+ if safe_path is None:
311
+ return self._obs("error", "Resolved file path escapes the workspace.")
312
+ if not os.path.isfile(safe_path):
313
+ return self._obs("error", f"File not found: {basename}")
314
+ if os.path.getsize(safe_path) > MAX_FILE_SIZE:
315
+ return self._obs("error", "File too large to read.")
316
+
317
+ try:
318
+ with open(safe_path, encoding="utf-8") as f:
319
+ content = f.read(MAX_FILE_SIZE)
320
+ except OSError:
321
+ return self._obs("error", "Failed to read file.")
322
+
323
+ if (
324
+ self._state.task_id == "task_2_medium_syntax"
325
+ and basename == "broken_pipeline.py"
326
+ ):
327
+ self._evidence["task_2"]["read_source"] = True
328
+ self._record_event("t2_read_source")
329
+ if self._state.task_id == "task_3_hard_e2e" and basename == "format_report.py":
330
+ self._evidence["task_3"]["read_formatter_source"] = True
331
+ self._record_event("t3_read_formatter_source")
332
+ return DataOpsObservation(
333
+ status="success",
334
+ stdout=content,
335
+ message=f"Read {len(content)} chars from {basename}",
336
+ )
337
+
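The read handler above reduces the requested path to its basename and then calls `_resolve_workspace_path`, whose body is not part of this excerpt. The sketch below shows one way such a confinement check could work, under the assumption that it resolves symlinks and compares against the workspace root. The function name `resolve_in_workspace` and its exact behaviour are hypothetical, not the commit's implementation.

```python
import os
from typing import Optional


def resolve_in_workspace(workspace_dir: str, filepath: str) -> Optional[str]:
    """Hypothetical check: map filepath to its basename inside workspace_dir."""
    basename = os.path.basename(filepath)
    if not basename:
        return None
    root = os.path.realpath(workspace_dir)
    candidate = os.path.realpath(os.path.join(root, basename))
    # realpath follows symlinks, so a link pointing outside the root is caught.
    if os.path.commonpath([root, candidate]) != root:
        return None
    return candidate
```

Taking the basename first means a request such as `../../etc/passwd` is treated as `passwd` inside the workspace, and the second check rejects `..` and symlinks that escape the root.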
338
+ def _handle_write(
339
+ self, payload: WriteFilePayload, timeout_s: Optional[float]
340
+ ) -> DataOpsObservation:
341
+ del timeout_s
342
+ basename = os.path.basename(payload.filepath)
343
+ if not self._is_allowed_file(TASK_ALLOWED_WRITE_FILES, basename):
344
+ return self._obs(
345
+ "error", f"Writing {basename} is not allowed for this task."
346
+ )
347
+
348
+ if (
349
+ self._state.task_id == "task_2_medium_syntax"
350
+ and basename == "broken_pipeline.py"
351
+ ):
352
+ if not self._evidence["task_2"]["read_source"]:
353
+ self._pending_events.append("t2_write_before_read")
354
+
355
+ safe_path = self._resolve_workspace_path(basename)
356
+ if safe_path is None:
357
+ return self._obs("error", "Resolved file path escapes the workspace.")
358
+
359
+ try:
360
+ with open(safe_path, "w", encoding="utf-8") as f:
361
+ f.write(payload.content)
362
+ except OSError:
363
+ return self._obs("error", "Failed to write file.")
364
+
365
+ self._record_write_evidence(basename, payload.content)
366
+ return self._obs("success", f"Wrote {len(payload.content)} chars to {basename}")
367
+
368
+ def _handle_run(
369
+ self, payload: RunScriptPayload, timeout_s: Optional[float]
370
+ ) -> DataOpsObservation:
371
+ basename = os.path.basename(payload.filepath)
372
+ if not self._is_allowed_file(TASK_ALLOWED_RUN_FILES, basename):
373
+ return self._obs(
374
+ "error", f"Executing {basename} is not allowed for this task."
375
+ )
376
+
377
+ script_path = self._resolve_workspace_path(basename)
378
+ if script_path is None:
379
+ return self._obs("error", "Resolved script path escapes the workspace.")
380
+ if not os.path.isfile(script_path):
381
+ return self._obs("error", f"Script not found: {basename}")
382
+
383
+ if (
384
+ self._state.task_id == "task_2_medium_syntax"
385
+ and basename == "broken_pipeline.py"
386
+ ):
387
+ if not self._evidence["task_2"]["read_source"]:
388
+ self._pending_events.append("t2_run_before_read")
389
+
390
+ timeout = self._resolve_timeout(timeout_s)
391
+ try:
392
+ result = run_python_script(
393
+ basename,
394
+ cwd=self._workspace_dir,
395
+ args=list(payload.args),
396
+ timeout_s=timeout,
397
+ stdout_limit=MAX_STDOUT_CHARS,
398
+ stderr_limit=MAX_STDERR_CHARS,
399
+ )
400
+ except OSError:
401
+ return self._obs("error", "Failed to execute script.")
402
+
403
+ if result.timed_out:
404
+ return self._obs("error", f"Script timed out ({timeout:.1f}s limit).")
405
+
406
+ self._record_run_evidence(basename, payload.args, result)
407
+ status = "success" if result.returncode == 0 else "error"
408
+ return DataOpsObservation(
409
+ status=status,
410
+ stdout=(result.stdout or "")[:MAX_STDOUT_CHARS],
411
+ stderr=(result.stderr or "")[:MAX_STDERR_CHARS],
412
+ message=f"Exit code: {result.returncode}",
413
+ )
414
+
415
+ def _handle_email(
416
+ self, payload: SendEmailPayload, timeout_s: Optional[float]
417
+ ) -> DataOpsObservation:
418
+ del timeout_s
419
+ if self._state.task_id not in TASK_EMAIL_ENABLED:
420
+ self._disallowed_tool_attempts += 1
421
+ self._pending_events.append("disallowed_tool")
422
+ return self._obs(
423
+ "error",
424
+ "Email is not available for this task. Use read_file, write_file, and invoke_python only.",
425
+ )
426
+
427
+ email = {
428
+ "to_email": payload.to_email,
429
+ "subject": payload.subject,
430
+ "body": payload.body,
431
+ }
432
+ self.email_outbox.append(email)
433
+ self._record_email_evidence(email)
434
+ return DataOpsObservation(
435
+ status="success",
436
+ email_delivery_status=f"Queued for {payload.to_email}",
437
+ message=f"Email queued for delivery to {payload.to_email}",
438
+ )
439
+
+     def _compute_reward(self, action: DataOpsAction, obs: DataOpsObservation) -> float:
+         current_score = self._current_task_score()
+         reward = current_score - self._grader_score
+         self._grader_score = current_score
+
+         if obs.status != "success":
+             reward += PENALTY_FAILURE
+
+         action_key = (
+             f"{action.action_type}:"
+             f"{json.dumps(action.payload, sort_keys=True, ensure_ascii=True)}"
+         )
+         if action_key == self._last_action_key:
+             reward += PENALTY_REPEAT
+         self._last_action_key = action_key
+
+         for event in self._pending_events:
+             if event == "disallowed_tool":
+                 reward += PENALTY_DISALLOWED_TOOL_UNIT * min(
+                     self._disallowed_tool_attempts, 12
+                 )
+                 continue
+             if event in PENALTY_EVENTS:
+                 reward += PENALTY_EVENTS[event]
+                 continue
+             reward += self._award_event(event)
+         return reward
+
+     def _award_event(self, event: str) -> float:
+         if event in self._milestones:
+             return 0.0
+         self._milestones.add(event)
+         return REWARD_EVENT_VALUES.get(event, 0.0)
473
+
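The per-step reward above combines three signals: the change in grader score since the last step, flat penalties for failed or repeated actions, and milestone bonuses that pay out once per episode. The stand-alone sketch below reproduces that bookkeeping in simplified form. The `RewardTracker` class is hypothetical and omits the disallowed-tool and penalty-event branches; the constants are copied from the module, and the rounding mirrors what `_step_locked` applies to the observation.

```python
from typing import Optional

PENALTY_FAILURE = -0.03
PENALTY_REPEAT = -0.08
MILESTONE_VALUES = {"t2_read_source": 0.04, "t2_candidate_compiles": 0.02}


class RewardTracker:
    """Hypothetical stand-alone version of the per-step reward bookkeeping."""

    def __init__(self) -> None:
        self.last_score = 0.0
        self.last_action_key: Optional[str] = None
        self.milestones: set[str] = set()

    def step(self, score: float, ok: bool, action_key: str, events: list[str]) -> float:
        # Base signal: change in grader score since the previous step.
        reward = score - self.last_score
        self.last_score = score
        if not ok:
            reward += PENALTY_FAILURE
        # Repeating the previous action verbatim is penalised.
        if action_key == self.last_action_key:
            reward += PENALTY_REPEAT
        self.last_action_key = action_key
        for event in events:
            # Each milestone pays out once per episode.
            if event not in self.milestones:
                self.milestones.add(event)
                reward += MILESTONE_VALUES.get(event, 0.0)
        return round(reward, 4)
```

Paying the score delta rather than the absolute score means the rewards sum to the final grader score plus shaping terms, which keeps the terminal grader the dominant signal, as the comment on `REWARD_EVENT_VALUES` intends.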
474
+ def _initial_evidence(self) -> dict[str, Any]:
475
+ return {
476
+ "task_1": {
477
+ "inspected_corrupted_rows": False,
478
+ "exact_cleanup": False,
479
+ "destructive_sql_attempted": False,
480
+ },
481
+ "task_2": {
482
+ "read_source": False,
483
+ "candidate_compiles": False,
484
+ "verified_fix": False,
485
+ },
486
+ "task_3": {
487
+ "matching_sql_executed": False,
488
+ "last_matching_sql_rows": [],
489
+ "read_formatter_source": False,
490
+ "report_data_matches_sql": False,
491
+ "formatter_compiles": False,
492
+ "format_output_matches_expected": False,
493
+ "last_formatter_output": "",
494
+ "email_matches_formatter_output": False,
495
+ "single_email_sent": True,
496
+ },
497
+ }
498
+
499
+ def _record_sql_select(self, query: str, rows: list[dict[str, Any]]) -> None:
500
+ if self._scenario.task_1 and self._state.task_id == "task_1_easy_anomaly":
501
+ row_ids = {int(row.get("id")) for row in rows if row.get("id") is not None}
502
+ corrupted = set(self._scenario.task_1.corrupted_row_ids)
503
+ if row_ids & corrupted:
504
+ self._evidence["task_1"]["inspected_corrupted_rows"] = True
505
+ self._record_event("t1_inspected_corruption")
506
+
507
+ if self._scenario.task_3 and self._state.task_id == "task_3_hard_e2e":
508
+ normalised_rows = normalize_task_3_rows(rows, require_headcount=True)
509
+ expected_rows = list(self._scenario.task_3.expected_rows)
510
+ if task_3_data_matches_expected(
511
+ normalised_rows,
512
+ expected_rows,
513
+ require_headcount=True,
514
+ ):
515
+ self._evidence["task_3"]["matching_sql_executed"] = True
516
+ self._evidence["task_3"]["last_matching_sql_rows"] = normalised_rows
517
+ self._record_event("t3_matching_sql")
518
+ elif rows:
519
+ self._record_event("t3_nonempty_select")
520
+
521
+ def _record_sql_mutation(self, query: str, rowcount: int) -> None:
522
+ del rowcount
523
+ if self._scenario.task_1 and self._state.task_id == "task_1_easy_anomaly":
524
+ exact_rows = self._current_transactions_rows()
525
+ expected_rows = list(self._scenario.task_1.expected_rows)
526
+ expected_by_id = {row["id"]: row for row in expected_rows}
527
+ actual_by_id = {row["id"]: row for row in exact_rows}
528
+ valid_rows_lost = any(
529
+ row_id not in actual_by_id for row_id in expected_by_id
530
+ )
531
+ valid_rows_changed = any(
532
+ actual_by_id[row_id] != expected_row
533
+ for row_id, expected_row in expected_by_id.items()
534
+ if row_id in actual_by_id
535
+ )
536
+ if exact_rows == expected_rows:
537
+ self._evidence["task_1"]["exact_cleanup"] = True
538
+ self._record_event("t1_exact_cleanup")
539
+ elif valid_rows_lost or valid_rows_changed:
540
+ self._evidence["task_1"]["destructive_sql_attempted"] = True
541
+ self._pending_events.append("destructive_sql")
542
+
543
+ def _record_write_evidence(self, basename: str, content: str) -> None:
544
+ if (
545
+ self._state.task_id == "task_2_medium_syntax"
546
+ and basename == "broken_pipeline.py"
547
+ ):
548
+ compiles = self._script_compiles(content, basename)
549
+ self._evidence["task_2"]["candidate_compiles"] = compiles
550
+ if compiles:
551
+ self._record_event("t2_candidate_compiles")
552
+ return
553
+
554
+ if not self._scenario.task_3 or self._state.task_id != "task_3_hard_e2e":
555
+ return
556
+
557
+ task_3 = self._evidence["task_3"]
558
+ if basename == "report_data.json":
559
+ try:
560
+ payload = json.loads(content)
561
+ except json.JSONDecodeError:
562
+ task_3["report_data_matches_sql"] = False
563
+ return
564
+ if not isinstance(payload, list):
565
+ task_3["report_data_matches_sql"] = False
566
+ return
567
+ normalised_rows = normalize_task_3_rows(payload, require_headcount=True)
568
+ expected_rows = list(self._scenario.task_3.expected_rows)
569
+ last_sql_rows = task_3.get("last_matching_sql_rows", [])
570
+ matches_sql = bool(last_sql_rows) and normalised_rows == last_sql_rows
571
+ matches_expected = task_3_data_matches_expected(
572
+ normalised_rows,
573
+ expected_rows,
574
+ require_headcount=True,
575
+ )
576
+ task_3["report_data_matches_sql"] = matches_sql and matches_expected
577
+ if task_3["report_data_matches_sql"]:
578
+ self._record_event("t3_report_data_verified")
579
+ return
580
+
581
+ if basename == "format_report.py":
582
+ compiles = self._script_compiles(content, basename)
583
+ task_3["formatter_compiles"] = compiles
584
+ if compiles:
585
+ self._record_event("t3_formatter_compiles")
586
+
587
+ def _record_run_evidence(
588
+ self,
589
+ basename: str,
590
+ args: list[str],
591
+ result: PythonRunResult,
592
+ ) -> None:
593
+ if (
594
+ self._state.task_id == "task_2_medium_syntax"
595
+ and basename == "broken_pipeline.py"
596
+ ):
597
+ if result.returncode == 0 and self._task_2_candidate_is_functional():
598
+ self._evidence["task_2"]["verified_fix"] = True
599
+ self._record_event("t2_verified_fix")
600
+ return
601
+
602
+ if not self._scenario.task_3 or self._state.task_id != "task_3_hard_e2e":
603
+ return
604
+ if basename != "format_report.py":
605
+ return
606
+
607
+ task_3 = self._evidence["task_3"]
608
+ stdout = (result.stdout or "").strip()
609
+ if (
610
+ result.returncode == 0
611
+ and self._task_3_args_reference_report_data(args)
612
+ and task_3.get("report_data_matches_sql")
613
+ and report_matches_expected(
614
+ stdout,
615
+ self._scenario.task_3.expected_rows,
616
+ self._scenario.task_3.target_date,
617
+ )
618
+ ):
619
+ task_3["format_output_matches_expected"] = True
620
+ task_3["last_formatter_output"] = stdout
621
+ self._record_event("t3_report_generated")
622
+
+     def _record_email_evidence(self, email: dict[str, str]) -> None:
+         if not self._scenario.task_3 or self._state.task_id != "task_3_hard_e2e":
+             return
+
+         task_3 = self._evidence["task_3"]
+         if len(self.email_outbox) > 1:
+             task_3["single_email_sent"] = False
+             self._pending_events.append("multiple_emails")
+
+         if (
+             task_3.get("format_output_matches_expected")
+             and task_3.get("single_email_sent")
+             and email.get("to_email") == self._scenario.task_3.recipient
+             and email.get("subject") == self._scenario.task_3.subject
+             and email.get("body", "").strip()
+             == str(task_3.get("last_formatter_output", "")).strip()
+         ):
+             task_3["email_matches_formatter_output"] = True
+             self._record_event("t3_email_verified")
+
+     def _task_2_candidate_is_functional(self) -> bool:
+         if not self._scenario.task_2:
+             return False
+         wrapper = textwrap.dedent(
+             f"""
+             import importlib.util
+             import json
+
+             spec = importlib.util.spec_from_file_location("candidate_pipeline", "broken_pipeline.py")
+             module = importlib.util.module_from_spec(spec)
+             assert spec.loader is not None
+             spec.loader.exec_module(module)
+
+             cases = {json.dumps(self._scenario.task_2.hidden_cases)}
+             results = [module.process_data_stream(case) for case in cases]
+             print("__RESULT__=" + json.dumps(results))
+             """
+         )
+         try:
+             result = run_python_code(
+                 wrapper,
+                 cwd=self._workspace_dir,
+                 timeout_s=DEFAULT_ACTION_TIMEOUT_S,
+                 stdout_limit=MAX_STDOUT_CHARS,
+                 stderr_limit=MAX_STDERR_CHARS,
+             )
+         except Exception:
+             return False
+         payload = next(
+             (
+                 line[len("__RESULT__=") :]
+                 for line in result.stdout.splitlines()
+                 if line.startswith("__RESULT__=")
+             ),
+             "",
+         )
+         try:
+             parsed = json.loads(payload) if payload else None
+         except json.JSONDecodeError:
+             parsed = None
+         expected = [list(batch) for batch in self._scenario.task_2.hidden_expected]
+         return result.returncode == 0 and parsed == expected
+
+     def _task_3_args_reference_report_data(self, args: list[str]) -> bool:
+         if len(args) != 1:
+             return False
+
+         expected_path = self._resolve_workspace_path("report_data.json")
+         if expected_path is None:
+             return False
+
+         candidate = args[0]
+         if os.path.isabs(candidate):
+             resolved = os.path.realpath(candidate)
+         else:
+             resolved = os.path.realpath(os.path.join(self._workspace_dir, candidate))
+         return resolved == expected_path
+
+     def _current_task_score(self) -> float:
+         if not self._state.task_id:
+             return 0.0
+         try:
+             from .grading import evaluate_task
+
+             return float(evaluate_task(self._state.task_id, self).get("score", 0.0))
+         except Exception:
+             logger.exception(
+                 "Failed to compute current grader score for reward shaping."
+             )
+             return self._grader_score
+
+     def _script_compiles(self, content: str, filename: str) -> bool:
+         try:
+             compile(content, filename, "exec")
+         except SyntaxError:
+             return False
+         return True
+
+     def _task_completed(self) -> bool:
+         if self._state.task_id == "task_1_easy_anomaly" and self._scenario.task_1:
+             return self._current_transactions_rows() == list(
+                 self._scenario.task_1.expected_rows
+             )
+         if self._state.task_id == "task_2_medium_syntax":
+             # The terminal grader can score below 1.0 even when verified_fix is set,
+             # because the score is split across visible, hidden, and provenance components.
+             return self._grader_score >= 1.0
+         if self._state.task_id == "task_3_hard_e2e":
+             # Some evidence flags can be set while the component-weighted grader
+             # score is still below 1.0.
+             return self._grader_score >= 1.0
+         return False
+
+     def _current_transactions_rows(self) -> list[dict[str, Any]]:
+         with sqlite3.connect(self._db_path) as conn:
+             conn.row_factory = sqlite3.Row
+             rows = conn.execute(
+                 "SELECT id, user_id, amount, status FROM transactions ORDER BY id"
+             ).fetchall()
+         return [
+             {
+                 "id": int(row["id"]),
+                 "user_id": int(row["user_id"]),
+                 "amount": None
+                 if row["amount"] is None
+                 else round(float(row["amount"]), 2),
+                 "status": str(row["status"]),
+             }
+             for row in rows
+         ]
+
+     def _record_event(self, event: str) -> None:
+         self._pending_events.append(event)
+
+     def _resolve_timeout(self, timeout_s: Optional[float]) -> float:
+         if timeout_s is None:
+             return DEFAULT_ACTION_TIMEOUT_S
+         return max(0.1, min(float(timeout_s), MAX_ACTION_TIMEOUT_S))
+
+     def _is_allowed_file(
+         self, allowed_registry: dict[str, frozenset[str]], basename: str
+     ) -> bool:
+         return basename in allowed_registry.get(self._state.task_id, frozenset())
+
+     def _resolve_workspace_path(self, basename: str) -> Optional[str]:
+         workspace_root = os.path.realpath(self._workspace_dir)
+         candidate = os.path.realpath(os.path.join(self._workspace_dir, basename))
+         if candidate == workspace_root:
+             return None
+         if not candidate.startswith(f"{workspace_root}{os.sep}"):
+             return None
+         return candidate
+
+     def _statement_type(self, query: str) -> str:
+         parts = query.split(None, 1)
+         return parts[0].upper() if parts else ""
+
+     def _validate_sql_action(self, query: str, statement_type: str) -> Optional[str]:
+         if not query:
+             return "SQL query cannot be empty."
+
+         policy = TASK_SQL_POLICIES.get(self._state.task_id)
+         if policy is None:
+             return "SQL is not available for the active task."
+         if statement_type not in policy.allowed_commands:
+             allowed = ", ".join(sorted(policy.allowed_commands))
+             return f"Only {allowed} statements are allowed for this task."
+
+         sanitized = self._strip_sql_literals_and_comments(query)
+         normalized = " ".join(sanitized.split())
+         lowered = normalized.lower()
+
+         if ";" in normalized:
+             return "Only a single SQL statement is allowed."
+         if any(
+             token in lowered
+             for token in ("pragma", "attach", "detach", "sqlite_", "alter ", "drop ")
+         ):
+             return "Query contains disallowed SQL constructs."
+         if statement_type == "DELETE" and not re.match(
+             rf"^delete\s+from\s+{re.escape(policy.required_table)}\s+where\b",
+             lowered,
+         ):
+             return f"DELETE statements must target '{policy.required_table}' with an explicit WHERE clause."
+
+         cte_names = self._extract_cte_names(normalized)
+         table_refs = self._extract_sql_table_refs(normalized)
+         if policy.required_table not in table_refs:
+             return f"Query must target the '{policy.required_table}' table."
+
+         allowed_refs = {policy.required_table, *cte_names}
+         disallowed = sorted(ref for ref in table_refs if ref not in allowed_refs)
+         if disallowed:
+             return f"Query references disallowed table(s): {', '.join(disallowed)}."
+         return None
+
+     def _strip_sql_literals_and_comments(self, query: str) -> str:
+         without_comments = _SQL_COMMENT_RE.sub(" ", query)
+         return _SQL_STRING_RE.sub("''", without_comments)
+
+     def _extract_cte_names(self, query: str) -> set[str]:
+         lowered = query.lower().lstrip()
+         if not lowered.startswith("with "):
+             return set()
+         return {match.group(1).lower() for match in _SQL_CTE_NAME_RE.finditer(query)}
+
+     def _extract_sql_table_refs(self, query: str) -> set[str]:
+         return {match.group(1).lower() for match in _SQL_TABLE_REF_RE.finditer(query)}
+
+     def _obs(
+         self, status: str, message: str, *, done: bool = False
+     ) -> DataOpsObservation:
+         return DataOpsObservation(
+             status=status,
+             message=message,
+             step_count=self._state.step_count,
+             max_steps=MAX_STEPS,
+             done=done,
+         )
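Reviewer note: the containment check in `_resolve_workspace_path` can be read in isolation as the sketch below. It is a minimal standalone illustration, not the environment's code: `resolve_inside` is a hypothetical helper name, and the example assumes only the standard library.

```python
import os
import tempfile
from typing import Optional


def resolve_inside(root: str, name: str) -> Optional[str]:
    """Return the real path of `name` under `root`, or None if it escapes."""
    real_root = os.path.realpath(root)
    candidate = os.path.realpath(os.path.join(real_root, name))
    # Reject the root itself and anything that is not strictly below it.
    if candidate == real_root or not candidate.startswith(real_root + os.sep):
        return None
    return candidate


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as workspace:
        print(resolve_inside(workspace, "report_data.json"))  # a path inside the workspace
        print(resolve_inside(workspace, "../outside.txt"))    # None: escapes the root
        print(resolve_inside(workspace, "."))                 # None: the root itself
```

Comparing against `real_root + os.sep` rather than `real_root` alone avoids accepting a sibling directory whose name merely starts with the workspace name.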
server/grading.py ADDED
@@ -0,0 +1,557 @@
+ """Terminal graders for the seeded DataOpsEnv benchmark."""
+
+ from __future__ import annotations
+
+ import ast
+ import json
+ import logging
+ import os
+ import sqlite3
+ from typing import Any
+
+ from server.dataops_env_environment import DataOpsEnvironment
+ from server.safe_exec import run_python_code, run_python_script
+ from server.task_specs import (
+     build_task_3_report,
+     normalize_task_2_output_rows,
+     normalize_task_3_rows,
+     report_matches_expected,
+     task_3_data_matches_expected,
+     task_3_semantic_match_fraction_rows,
+     task_3_semantic_match_fraction_text,
+ )
+
+ logger = logging.getLogger(__name__)
+ SCRIPT_TIMEOUT_S = 10
+ INTERNAL_STDOUT_LIMIT = 50_000
+ INTERNAL_STDERR_LIMIT = 10_000
+
+
+ def evaluate_task(task_id: str, env: DataOpsEnvironment) -> dict[str, Any]:
+     graders = {
+         "task_1_easy_anomaly": _grade_task_1,
+         "task_2_medium_syntax": _grade_task_2,
+         "task_3_hard_e2e": _grade_task_3,
+     }
+     grader = graders.get(task_id)
+     if grader is None:
+         return {"task_id": task_id, "score": 0.0, "details": {"error": "Unknown task"}}
+
+     score, details = grader(env)
+     return {"task_id": task_id, "score": round(score, 2), "details": details}
+
+
+ def _grade_task_1(env: DataOpsEnvironment) -> tuple[float, dict[str, Any]]:
+     if env.scenario.task_1 is None:
+         return 0.0, {"error": "Task 1 scenario missing."}
+
+     try:
+         actual_rows = _current_transactions_rows(env.db_path)
+     except Exception:
+         logger.exception("Task 1 grading error")
+         return 0.0, {"error": "Internal grading error."}
+
+     expected_rows = list(env.scenario.task_1.expected_rows)
+     corrupted_ids = set(env.scenario.task_1.corrupted_row_ids)
+     actual_ids = {row["id"] for row in actual_rows}
+     expected_ids = {row["id"] for row in expected_rows}
+     corrupted_remaining = sorted(actual_ids & corrupted_ids)
+     rewritten_corrupted = [
+         row for row in actual_rows if row["id"] in corrupted_ids and row["amount"] is not None
+     ]
+     valid_rows_intact = all(
+         any(actual == expected for actual in actual_rows) for expected in expected_rows
+     )
+
+     details: dict[str, Any] = {
+         "expected_row_ids": sorted(expected_ids),
+         "actual_row_ids": sorted(actual_ids),
+         "corrupted_row_ids": sorted(corrupted_ids),
+         "corrupted_remaining": corrupted_remaining,
+         "valid_rows_intact": valid_rows_intact,
+     }
+
+     if actual_rows == expected_rows:
+         details["reason"] = "Perfect - corrupted rows were deleted and all valid rows were preserved."
+         details["components"] = {
+             "exact_cleanup": {"score": 1.0, "max": 1.0, "passed": True},
+         }
+         return 1.0, details
+
+     if rewritten_corrupted:
+         details["reason"] = "Corrupted rows were rewritten instead of being deleted."
+         details["components"] = {
+             "exact_cleanup": {"score": 0.0, "max": 1.0, "passed": False},
+         }
+         return 0.0, details
+
+     if valid_rows_intact and corrupted_remaining:
+         fraction_removed = 1.0 - (len(corrupted_remaining) / max(len(corrupted_ids), 1))
+         score = round(0.25 * max(fraction_removed, 0.0), 4)
+         details["reason"] = "Some corrupted rows were removed, but cleanup is incomplete."
+         details["components"] = {
+             "partial_cleanup": {"score": score, "max": 0.25, "passed": False},
+         }
+         return score, details
+
+     details["reason"] = "The transaction table does not match the required cleaned state."
+     details["components"] = {
+         "exact_cleanup": {"score": 0.0, "max": 1.0, "passed": False},
+     }
+     return 0.0, details
+
+
+ def _grade_task_2(env: DataOpsEnvironment) -> tuple[float, dict[str, Any]]:
+     if env.scenario.task_2 is None:
+         return 0.0, {"error": "Task 2 scenario missing."}
+
+     script = os.path.join(env.workspace_dir, "broken_pipeline.py")
+     if not os.path.isfile(script):
+         return 0.0, {
+             "reason": "broken_pipeline.py not found.",
+             "components": {
+                 "script_present": {"score": 0.0, "max": 1.0, "passed": False},
+             },
+         }
+
+     try:
+         with open(script, encoding="utf-8") as f:
+             source = f.read()
+         static = _inspect_task_2_source(source)
+         main_result = run_python_script(
+             "broken_pipeline.py",
+             cwd=env.workspace_dir,
+             args=[],
+             timeout_s=SCRIPT_TIMEOUT_S,
+             stdout_limit=INTERNAL_STDOUT_LIMIT,
+             stderr_limit=INTERNAL_STDERR_LIMIT,
+         )
+         visible_result = _run_task_2_case_check(
+             env.workspace_dir,
+             env.scenario.task_2.visible_batch,
+             env.scenario.task_2.visible_expected,
+         )
+         hidden_result = _run_task_2_hidden_tests(
+             env.workspace_dir,
+             env.scenario.task_2.hidden_cases,
+             env.scenario.task_2.hidden_expected,
+         )
+     except Exception:
+         logger.exception("Task 2 grading error")
+         return 0.0, {"error": "Internal grading error."}
+
+     if main_result.timed_out or visible_result["timed_out"] or hidden_result["timed_out"]:
+         return 0.0, {"reason": "Script timed out.", "components": {}}
+
+     hidden_score = round(0.60 * hidden_result["pass_fraction"], 4)
+     visible_score = 0.25 if visible_result["passed"] and main_result.returncode == 0 else 0.0
+     execution_score = 0.15 if env.evidence.get("task_2", {}).get("verified_fix") else 0.0
+     components: dict[str, Any] = {
+         "hidden_functional": {
+             "score": hidden_score,
+             "max": 0.60,
+             "passed": hidden_result["passed"],
+         },
+         "visible_pipeline": {
+             "score": visible_score,
+             "max": 0.25,
+             "passed": visible_result["passed"] and main_result.returncode == 0,
+         },
+         "execution_provenance": {
+             "score": execution_score,
+             "max": 0.15,
+             "passed": bool(env.evidence.get("task_2", {}).get("verified_fix")),
+         },
+     }
+     score = round(sum(component["score"] for component in components.values()), 4)
+     details = {
+         "main_exit_code": main_result.returncode,
+         "main_stdout": main_result.stdout[:500],
+         "main_stderr": main_result.stderr[:500],
+         "visible_batch_ok": visible_result["passed"],
+         "hidden_tests_passed": hidden_result["passed"],
+         "hidden_pass_fraction": hidden_result["pass_fraction"],
+         "hidden_case_passes": hidden_result["case_passes"],
+         "static_checks": static,
+         "components": components,
+     }
+
+     if score == 1.0:
+         details["reason"] = "Seeded hidden tests and the visible verification run both pass."
+     elif hidden_result["passed"] and main_result.returncode == 0:
+         details["reason"] = "The ETL transform is correct, but the agent never verified it through the run action."
+     elif hidden_result["pass_fraction"] > 0 and main_result.returncode == 0:
+         details["reason"] = "The repair improves the ETL transform, but it still fails some seeded cases."
+     elif hidden_result["pass_fraction"] > 0:
+         details["reason"] = "The core transform improved, but the runnable script entrypoint still drifts."
+     elif main_result.returncode == 0:
+         details["reason"] = "The script runs, but it does not yet produce the required normalized records."
+     else:
+         details["reason"] = "The repair is still incorrect or incomplete."
+     return score, details
+
+
+ def _grade_task_3(env: DataOpsEnvironment) -> tuple[float, dict[str, Any]]:
+     if env.scenario.task_3 is None:
+         return 0.0, {"error": "Task 3 scenario missing."}
+
+     scenario = env.scenario.task_3
+     evidence = env.evidence.get("task_3", {})
+     expected_rows = list(scenario.expected_rows)
+     expected_report = build_task_3_report(expected_rows, scenario.target_date)
+
+     report_data = _load_task_3_data(env.workspace_dir, expected_rows)
+     formatter = _run_task_3_formatter(env.workspace_dir, expected_rows, scenario.target_date)
+     email = _score_task_3_email(env, expected_report)
+
+     report_exact_and_proven = bool(
+         report_data["matches_expected"] and evidence.get("report_data_matches_sql")
+     )
+     formatter_exact_and_proven = bool(
+         formatter["matches_expected"] and evidence.get("format_output_matches_expected")
+     )
+
+     components: dict[str, Any] = {
+         "sql_provenance": {
+             "score": 0.20 if evidence.get("matching_sql_executed") else 0.0,
+             "max": 0.20,
+             "passed": bool(evidence.get("matching_sql_executed")),
+         },
+         "report_data": {
+             "score": 0.20 if report_exact_and_proven else 0.05 if report_data["matches_expected"] else 0.0,
+             "max": 0.20,
+             "passed": report_exact_and_proven,
+         },
+         "formatter": {
+             "score": 0.25 if formatter_exact_and_proven else 0.05 if formatter["runs"] else 0.0,
+             "max": 0.25,
+             "passed": formatter_exact_and_proven,
+         },
+         "email": {
+             "score": email["score"],
+             "max": 0.35,
+             "passed": email["passed"],
+         },
+     }
+     score = round(sum(component["score"] for component in components.values()), 4)
+     details: dict[str, Any] = {
+         "target_date": scenario.target_date,
+         "expected_recipient": scenario.recipient,
+         "expected_subject": scenario.subject,
+         "report_data": report_data["details"],
+         "formatter": formatter["details"],
+         "email": email["details"],
+         "evidence": evidence,
+         "components": components,
+     }
+
+     if score == 1.0:
+         details["reason"] = "Perfect - the seeded SQL slice, JSON output, formatter run, and final email all align."
+     elif score >= 0.55:
+         details["reason"] = "Strong progress - some of the seeded workflow is correct, but provenance is incomplete."
+     elif score > 0:
+         details["reason"] = "Partial progress - artifacts exist, but the end-to-end incident workflow is not proven."
+     else:
+         details["reason"] = "The seeded hard task is still unsolved."
+     return score, details
+
+
+ def _inspect_task_2_source(source: str) -> dict[str, Any]:
+     try:
+         tree = ast.parse(source)
+     except SyntaxError as exc:
+         return {"passed": False, "error": str(exc), "has_function": False}
+
+     functions = [node for node in tree.body if isinstance(node, ast.FunctionDef)]
+     target = next((node for node in functions if node.name == "process_data_stream"), None)
+     passed = target is not None and len(target.args.args) == 1
+     return {"passed": passed, "has_function": target is not None}
+
+
+ def _run_task_2_case_check(
+     workspace_dir: str,
+     batch: tuple[dict[str, Any], ...],
+     expected: tuple[dict[str, Any], ...],
+ ) -> dict[str, Any]:
+     wrapper = f"""
+ import importlib.util
+ import json
+
+ spec = importlib.util.spec_from_file_location("candidate_pipeline", "broken_pipeline.py")
+ module = importlib.util.module_from_spec(spec)
+ assert spec.loader is not None
+ spec.loader.exec_module(module)
+
+ batch = {json.dumps(list(batch))}
+ results = module.process_data_stream(batch)
+ print("__RESULT__=" + json.dumps(results))
+ """
+     result = run_python_code(
+         wrapper,
+         cwd=workspace_dir,
+         timeout_s=SCRIPT_TIMEOUT_S,
+         stdout_limit=INTERNAL_STDOUT_LIMIT,
+         stderr_limit=INTERNAL_STDERR_LIMIT,
+     )
+     payload = next(
+         (
+             line[len("__RESULT__=") :]
+             for line in result.stdout.splitlines()
+             if line.startswith("__RESULT__=")
+         ),
+         "",
+     )
+     try:
+         parsed = json.loads(payload) if payload else None
+     except json.JSONDecodeError:
+         parsed = None
+     normalised = normalize_task_2_output_rows(parsed)
+     ok = result.returncode == 0 and normalised == list(expected)
+     return {
+         "passed": ok,
+         "timed_out": result.timed_out,
+         "stdout": result.stdout[:500],
+         "stderr": result.stderr[:500],
+         "actual": normalised,
+     }
+
+
+ def _run_task_2_hidden_tests(
+     workspace_dir: str,
+     hidden_cases: tuple[tuple[dict[str, Any], ...], ...],
+     hidden_expected: tuple[tuple[dict[str, Any], ...], ...],
+ ) -> dict[str, Any]:
+     wrapper = f"""
+ import importlib.util
+ import json
+
+ spec = importlib.util.spec_from_file_location("candidate_pipeline", "broken_pipeline.py")
+ module = importlib.util.module_from_spec(spec)
+ assert spec.loader is not None
+ spec.loader.exec_module(module)
+
+ cases = {json.dumps([list(batch) for batch in hidden_cases])}
+ results = [module.process_data_stream(case) for case in cases]
+ print("__RESULT__=" + json.dumps(results))
+ """
+     result = run_python_code(
+         wrapper,
+         cwd=workspace_dir,
+         timeout_s=SCRIPT_TIMEOUT_S,
+         stdout_limit=INTERNAL_STDOUT_LIMIT,
+         stderr_limit=INTERNAL_STDERR_LIMIT,
+     )
+     payload = next(
+         (
+             line[len("__RESULT__=") :]
+             for line in result.stdout.splitlines()
+             if line.startswith("__RESULT__=")
+         ),
+         "",
+     )
+     try:
+         parsed = json.loads(payload) if payload else None
+     except json.JSONDecodeError:
+         parsed = None
+     if not isinstance(parsed, list):
+         parsed = []
+
+     actual_batches = [
+         normalize_task_2_output_rows(batch)
+         for batch in parsed
+     ]
+     expected = [list(batch) for batch in hidden_expected]
+     case_passes = [
+         actual == expected_case
+         for actual, expected_case in zip(actual_batches, expected, strict=False)
+     ]
+     if len(case_passes) < len(expected):
+         case_passes.extend([False] * (len(expected) - len(case_passes)))
+     pass_fraction = (
+         sum(1 for passed in case_passes if passed) / len(expected)
+         if expected
+         else 0.0
+     )
+     return {
+         "passed": result.returncode == 0 and len(actual_batches) == len(expected) and all(case_passes),
+         "timed_out": result.timed_out,
+         "stdout": result.stdout[:500],
+         "stderr": result.stderr[:500],
+         "actual": actual_batches,
+         "case_passes": case_passes,
+         "pass_fraction": round(pass_fraction, 4),
+     }
+
+
+ def _load_task_3_data(
+     workspace_dir: str, expected_rows: list[dict[str, Any]]
+ ) -> dict[str, Any]:
+     report_json = os.path.join(workspace_dir, "report_data.json")
+     if not os.path.isfile(report_json):
+         return {
+             "matches_expected": False,
+             "details": {"exists": False, "reason": "report_data.json not found."},
+         }
+
+     try:
+         with open(report_json, encoding="utf-8") as f:
+             payload = json.load(f)
+     except (OSError, json.JSONDecodeError) as exc:
+         return {
+             "matches_expected": False,
+             "details": {"exists": True, "reason": str(exc)},
+         }
+
+     if not isinstance(payload, list):
+         return {
+             "matches_expected": False,
+             "details": {
+                 "exists": True,
+                 "reason": "report_data.json must contain a JSON list.",
+             },
+         }
+
+     rows = normalize_task_3_rows(payload, require_headcount=True)
+     matches_expected = bool(rows) and task_3_data_matches_expected(
+         rows,
+         expected_rows,
+         require_headcount=True,
+     )
+     semantic_fraction = task_3_semantic_match_fraction_rows(rows, expected_rows)
+     return {
+         "matches_expected": matches_expected,
+         "details": {
+             "exists": True,
+             "rows_valid": bool(rows),
+             "rows_match_expected": matches_expected,
+             "semantic_fraction": round(semantic_fraction, 4),
+         },
+     }
+
+
+ def _run_task_3_formatter(
+     workspace_dir: str,
+     expected_rows: list[dict[str, Any]],
+     target_date: str,
+ ) -> dict[str, Any]:
+     script = os.path.join(workspace_dir, "format_report.py")
+     if not os.path.isfile(script):
+         return {
+             "runs": False,
+             "matches_expected": False,
+             "details": {"reason": "format_report.py not found."},
+         }
+
+     try:
+         result = run_python_script(
+             "format_report.py",
+             cwd=workspace_dir,
+             args=["report_data.json"],
+             timeout_s=SCRIPT_TIMEOUT_S,
+             stdout_limit=INTERNAL_STDOUT_LIMIT,
+             stderr_limit=INTERNAL_STDERR_LIMIT,
+         )
+     except Exception as exc:
+         return {
+             "runs": False,
+             "matches_expected": False,
+             "details": {"reason": str(exc)},
+         }
+     if result.timed_out:
+         return {
+             "runs": False,
+             "matches_expected": False,
+             "details": {"reason": "Formatter timed out."},
+         }
+
+     stdout = (result.stdout or "").strip()
+     matches_expected = result.returncode == 0 and report_matches_expected(
+         stdout,
+         expected_rows,
+         target_date,
+     )
+     return {
+         "runs": result.returncode == 0,
+         "matches_expected": matches_expected,
+         "details": {
+             "exit_code": result.returncode,
+             "stdout": stdout[:500],
+             "stderr": (result.stderr or "")[:500],
+             "semantic_fraction": round(
+                 task_3_semantic_match_fraction_text(stdout, expected_rows, target_date),
+                 4,
+             ),
+         },
+     }
+
+
+ def _score_task_3_email(
+     env: DataOpsEnvironment, expected_report: str
+ ) -> dict[str, Any]:
+     scenario = env.scenario.task_3
+     assert scenario is not None
+     evidence = env.evidence.get("task_3", {})
+     outbox = env.email_outbox
+     if not outbox:
+         return {
+             "score": 0.0,
+             "passed": False,
+             "details": {"reason": "No email sent."},
+         }
+
+     email = outbox[-1]
+     recipient_ok = email.get("to_email") == scenario.recipient
+     subject_ok = email.get("subject") == scenario.subject
+     body = str(email.get("body", "")).strip()
+     body_ok = body == expected_report.strip()
+     proven = bool(evidence.get("email_matches_formatter_output")) and len(outbox) == 1
+
+     score = 0.0
+     if recipient_ok:
+         score += 0.05
+     if subject_ok:
+         score += 0.05
+     if body_ok and proven:
+         score += 0.25
+
+     return {
+         "score": score,
+         "passed": score == 0.35,
+         "details": {
+             "emails_sent": len(outbox),
+             "recipient_ok": recipient_ok,
+             "subject_ok": subject_ok,
+             "body_ok": body_ok,
+             "proven": proven,
+             "semantic_fraction": round(
+                 task_3_semantic_match_fraction_text(
+                     body,
+                     list(scenario.expected_rows),
+                     scenario.target_date,
+                 ),
+                 4,
+             ),
+         },
+     }
+
+
+ def _current_transactions_rows(db_path: str) -> list[dict[str, Any]]:
+     with sqlite3.connect(db_path) as conn:
+         conn.row_factory = sqlite3.Row
+         table_exists = conn.execute(
+             "SELECT name FROM sqlite_master WHERE type='table' AND name='transactions'"
+         ).fetchone()
+         if not table_exists:
+             return []
+         rows = conn.execute(
+             "SELECT id, user_id, amount, status FROM transactions ORDER BY id"
+         ).fetchall()
+     return [
+         {
+             "id": int(row["id"]),
+             "user_id": int(row["user_id"]),
+             "amount": None if row["amount"] is None else round(float(row["amount"]), 2),
+             "status": str(row["status"]),
+         }
+         for row in rows
+     ]
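Reviewer note: the graders above pass data into a child interpreter by formatting `json.dumps(...)` output into Python source, and read results back from a `__RESULT__=` sentinel line. The sketch below illustrates that round trip under stated assumptions: `extract_result` and `run_with_cases` are hypothetical names, and the child script is a stand-in, not the benchmark's wrapper. It embeds the JSON text as a string literal and decodes it with `json.loads`, which is one way to avoid relying on JSON output also being valid Python (JSON `null`, `true`, and `false` are not Python literals).

```python
import json
import subprocess
import sys

SENTINEL = "__RESULT__="


def extract_result(stdout: str):
    """Return the decoded payload of the first sentinel line, or None."""
    for line in stdout.splitlines():
        if line.startswith(SENTINEL):
            try:
                return json.loads(line[len(SENTINEL):])
            except json.JSONDecodeError:
                return None
    return None


def run_with_cases(cases):
    """Run a tiny child script that doubles each case and reports via the sentinel."""
    # repr() of the JSON text is a valid Python string literal, so values such
    # as null/true/false never have to parse as Python source in the child.
    child = (
        "import json\n"
        f"cases = json.loads({json.dumps(cases)!r})\n"
        "print('debug noise')\n"
        "doubled = [c * 2 if c is not None else None for c in cases]\n"
        f"print({SENTINEL!r} + json.dumps(doubled))\n"
    )
    done = subprocess.run(
        [sys.executable, "-c", child], capture_output=True, text=True, timeout=30
    )
    return extract_result(done.stdout)


if __name__ == "__main__":
    print(run_with_cases([1, None, 3]))
```

Only the sentinel line is parsed, so any other output the child prints is ignored rather than treated as a result.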
server/requirements.txt ADDED
@@ -0,0 +1,10 @@
+ openenv-core[core]>=0.2.2
+ fastapi>=0.115.0
+ starlette>=0.46.0,<0.52.0
+ uvicorn[standard]>=0.34.0
+ pydantic>=2.10.0
+ pyyaml>=6.0.2
+ openai>=1.60.0
+ requests>=2.32.0
+ wsproto>=1.3.2
+ python-dotenv>=1.0.0
server/safe_exec.py ADDED
@@ -0,0 +1,195 @@
+ from __future__ import annotations
+
+ import json
+ import math
+ import os
+ import signal
+ import subprocess
+ import sys
+ import tempfile
+ from dataclasses import dataclass
+
+ DEFAULT_ADDRESS_SPACE_BYTES = 512 * 1024 * 1024
+ DEFAULT_FILE_BYTES = 2 * 1024 * 1024
+ DEFAULT_OPEN_FILES = 64
+ DEFAULT_PROCESSES = 32
+
+ _RUNNER_BOOTSTRAP = r"""
+ import json
+ import runpy
+ import sys
+
+ try:
+     import resource
+ except ImportError:  # pragma: no cover
+     resource = None
+
+
+ def _set_limit(name, value):
+     if resource is None or not hasattr(resource, name):
+         return
+     limit = int(value)
+     try:
+         _, current_hard = resource.getrlimit(getattr(resource, name))
+         soft = min(limit, current_hard) if current_hard >= 0 else limit
+         resource.setrlimit(getattr(resource, name), (soft, current_hard))
+     except (OSError, ValueError):
+         return
+
+
+ config = json.loads(sys.argv[1])
+ _set_limit("RLIMIT_CORE", 0)
+ _set_limit("RLIMIT_CPU", config["cpu_seconds"])
+ _set_limit("RLIMIT_FSIZE", config["file_bytes"])
+ _set_limit("RLIMIT_NOFILE", config["open_files"])
+ _set_limit("RLIMIT_AS", config["address_space_bytes"])
+ _set_limit("RLIMIT_NPROC", config["processes"])
+
+ mode = config["mode"]
+ if mode == "script":
+     script = sys.argv[2]
+     sys.argv = sys.argv[2:]
+     runpy.run_path(script, run_name="__main__")
+ elif mode == "code":
+     sys.argv = ["-c"]
+     exec(config["code"], {"__name__": "__main__"})
+ else:  # pragma: no cover
+     raise SystemExit(f"Unsupported execution mode: {mode}")
+ """
+
+
+ @dataclass(frozen=True)
+ class PythonRunResult:
+     returncode: int
+     stdout: str
+     stderr: str
+     timed_out: bool = False
+
+
+ def _safe_env(workspace_dir: str) -> dict[str, str]:
+     return {
+         "HOME": workspace_dir,
+         "TMPDIR": workspace_dir,
+         "LANG": "C.UTF-8",
+         "LC_ALL": "C.UTF-8",
+         "PATH": "",
+         "PYTHONDONTWRITEBYTECODE": "1",
+         "PYTHONHASHSEED": "0",
+         "PYTHONIOENCODING": "utf-8",
+         "PYTHONNOUSERSITE": "1",
+     }
+
+
+ def _limit_config(timeout_s: float) -> dict[str, int]:
+     return {
+         "cpu_seconds": max(1, int(math.ceil(timeout_s)) + 1),
+         "file_bytes": DEFAULT_FILE_BYTES,
+         "open_files": DEFAULT_OPEN_FILES,
+         "address_space_bytes": DEFAULT_ADDRESS_SPACE_BYTES,
+         "processes": DEFAULT_PROCESSES,
+     }
+
+
+ def _read_limited_text(handle, limit: int) -> str:
+     handle.seek(0)
+     data = handle.read(limit + 1)
+     if isinstance(data, bytes):
+         return data.decode("utf-8", errors="replace")[:limit]
+     return str(data)[:limit]
+
+
+ def _terminate_process(proc: subprocess.Popen[bytes]) -> None:
+     if proc.poll() is not None:
+         return
+     if os.name != "nt":
+         try:
+             os.killpg(proc.pid, signal.SIGKILL)
+             return
+         except ProcessLookupError:
+             return
+     proc.kill()
+
+
+ def _run_python_command(
+     config: dict[str, object],
+     *,
+     cwd: str,
+     argv: list[str],
+     timeout_s: float,
+     stdout_limit: int,
+     stderr_limit: int,
+ ) -> PythonRunResult:
+     command = [
+         sys.executable,
+         "-I",
+         "-B",
+         "-c",
+         _RUNNER_BOOTSTRAP,
+         json.dumps(config, ensure_ascii=True),
+         *argv,
+     ]
+     start_new_session = os.name != "nt"
+
+     with tempfile.TemporaryFile() as stdout_file, tempfile.TemporaryFile() as stderr_file:
+         proc = subprocess.Popen(
+             command,
+             cwd=cwd,
+             env=_safe_env(cwd),
+             stdin=subprocess.DEVNULL,
+             stdout=stdout_file,
+             stderr=stderr_file,
+             start_new_session=start_new_session,
+         )
+         timed_out = False
+         try:
+             proc.wait(timeout=timeout_s)
+         except subprocess.TimeoutExpired:
+             timed_out = True
+             _terminate_process(proc)
+             proc.wait()
+
+         return PythonRunResult(
+             returncode=proc.returncode if proc.returncode is not None else -1,
+             stdout=_read_limited_text(stdout_file, stdout_limit),
+             stderr=_read_limited_text(stderr_file, stderr_limit),
+             timed_out=timed_out,
+         )
+
+
159
+ def run_python_script(
160
+ script_name: str,
161
+ *,
162
+ cwd: str,
163
+ args: list[str],
164
+ timeout_s: float,
165
+ stdout_limit: int,
166
+ stderr_limit: int,
167
+ ) -> PythonRunResult:
168
+ config = {"mode": "script", **_limit_config(timeout_s)}
169
+ return _run_python_command(
170
+ config,
171
+ cwd=cwd,
172
+ argv=[script_name, *args],
173
+ timeout_s=timeout_s,
174
+ stdout_limit=stdout_limit,
175
+ stderr_limit=stderr_limit,
176
+ )
177
+
178
+
179
+ def run_python_code(
180
+ code: str,
181
+ *,
182
+ cwd: str,
183
+ timeout_s: float,
184
+ stdout_limit: int,
185
+ stderr_limit: int,
186
+ ) -> PythonRunResult:
187
+ config = {"mode": "code", "code": code, **_limit_config(timeout_s)}
188
+ return _run_python_command(
189
+ config,
190
+ cwd=cwd,
191
+ argv=[],
192
+ timeout_s=timeout_s,
193
+ stdout_limit=stdout_limit,
194
+ stderr_limit=stderr_limit,
195
+ )
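The core pattern in `_run_python_command` above (spool output to a temporary file, wait with a timeout, kill the whole process group on expiry, then read back a bounded amount of text) can be sketched in isolation. This is a minimal illustration, not the module itself: `run_bounded` is a hypothetical name, and the sketch assumes the resource limits and the scrubbed environment from `_safe_env` are left out for brevity.

```python
import os
import signal
import subprocess
import sys
import tempfile


def run_bounded(code: str, *, timeout_s: float, limit: int) -> tuple[int, str, bool]:
    """Run `code` in a child interpreter; return (returncode, clipped stdout, timed_out)."""
    posix = os.name != "nt"
    with tempfile.TemporaryFile() as out:
        proc = subprocess.Popen(
            [sys.executable, "-I", "-B", "-c", code],
            stdin=subprocess.DEVNULL,
            stdout=out,
            stderr=subprocess.DEVNULL,
            start_new_session=posix,  # child leads its own process group on POSIX
        )
        timed_out = False
        try:
            proc.wait(timeout=timeout_s)
        except subprocess.TimeoutExpired:
            timed_out = True
            if posix:
                try:
                    os.killpg(proc.pid, signal.SIGKILL)  # kill the whole group
                except ProcessLookupError:
                    pass  # the child exited between the timeout and the kill
            else:
                proc.kill()
            proc.wait()
        # Read one byte past the limit at most, so a huge output is never loaded whole.
        out.seek(0)
        text = out.read(limit + 1).decode("utf-8", errors="replace")[:limit]
    return proc.returncode, text, timed_out


if __name__ == "__main__":
    print(run_bounded("print('x' * 100)", timeout_s=10.0, limit=8))
```

Writing to a temporary file rather than a pipe means a chatty child cannot block on a full pipe buffer while the parent is waiting, which is presumably why the runner takes the same approach.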
server/session_manager.py ADDED
@@ -0,0 +1,128 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ import threading
5
+ import uuid
6
+ from dataclasses import dataclass
7
+ from typing import Optional
8
+
9
+ from server.dataops_env_environment import DataOpsEnvironment
10
+
11
+
12
+ @dataclass
13
+ class SessionRecord:
14
+ env: DataOpsEnvironment
15
+ last_access_at: float
16
+
17
+
18
+ class EnvironmentSessionManager:
19
+ """Small in-memory session store for isolated environment instances."""
20
+
21
+ def __init__(
22
+ self,
23
+ *,
24
+ max_sessions: int = 128,
25
+ session_timeout_s: float = 1800.0,
26
+ ) -> None:
27
+ self._lock = threading.Lock()
28
+ self._sessions: dict[str, SessionRecord] = {}
29
+ self._max_sessions = max(1, max_sessions)
30
+ self._session_timeout_s = max(1.0, session_timeout_s)
31
+
32
+ def reset_session(
33
+ self,
34
+ *,
35
+ task_id: str,
36
+ seed: Optional[int],
37
+ episode_id: Optional[str],
38
+ session_id: Optional[str],
39
+ ) -> tuple[str, DataOpsEnvironment, object]:
40
+ now = time.monotonic()
41
+ to_close: list[DataOpsEnvironment] = []
42
+
43
+ with self._lock:
44
+ to_close.extend(self._collect_expired_envs_locked(now))
45
+
46
+ record = self._sessions.get(session_id) if session_id else None
47
+ if record is None:
48
+ resolved_session_id = str(uuid.uuid4())
49
+ to_close.extend(self._evict_if_full_locked(now))
50
+ env = DataOpsEnvironment()
51
+ self._sessions[resolved_session_id] = SessionRecord(
52
+ env=env,
53
+ last_access_at=now,
54
+ )
55
+ else:
56
+ resolved_session_id = session_id or str(uuid.uuid4())
57
+ record.last_access_at = now
58
+ env = record.env
59
+
60
+ self._close_envs(to_close)
61
+ obs = env.reset(seed=seed, episode_id=episode_id, task_id=task_id)
62
+ return resolved_session_id, env, obs
63
+
64
+ def get_session(
65
+ self, session_id: Optional[str]
66
+ ) -> tuple[Optional[str], Optional[DataOpsEnvironment]]:
67
+ now = time.monotonic()
68
+ to_close: list[DataOpsEnvironment] = []
69
+
70
+ with self._lock:
71
+ to_close.extend(self._collect_expired_envs_locked(now))
72
+
73
+ if session_id:
74
+ record = self._sessions.get(session_id)
75
+ if record is not None:
76
+ record.last_access_at = now
77
+ env = record.env
78
+ else:
79
+ env = None
80
+ result = (session_id, env)
81
+ else:
82
+ result = (None, None)
83
+
84
+ self._close_envs(to_close)
85
+ return result
86
+
87
+ def close_all(self) -> None:
88
+ with self._lock:
89
+ records = list(self._sessions.values())
90
+ self._sessions.clear()
91
+
92
+ self._close_envs([record.env for record in records])
93
+
94
+ def _collect_expired_envs_locked(self, now: float) -> list[DataOpsEnvironment]:
95
+ expired_ids = [
96
+ session_id
97
+ for session_id, record in self._sessions.items()
98
+ if now - record.last_access_at > self._session_timeout_s
99
+ ]
100
+ return self._remove_sessions_locked(expired_ids)
101
+
102
+ def _evict_if_full_locked(self, now: float) -> list[DataOpsEnvironment]:
103
+ if len(self._sessions) < self._max_sessions:
104
+ return []
105
+
106
+ oldest_session_id = min(
107
+ self._sessions,
108
+ key=lambda session_id: self._sessions[session_id].last_access_at,
109
+ )
110
+ return self._remove_sessions_locked([oldest_session_id])
111
+
112
+ def _remove_sessions_locked(self, session_ids: list[str]) -> list[DataOpsEnvironment]:
113
+ removed: list[DataOpsEnvironment] = []
114
+ for session_id in session_ids:
115
+ record = self._sessions.pop(session_id, None)
116
+ if record is not None:
117
+ removed.append(record.env)
118
+ return removed
119
+
120
+ def _close_envs(self, envs: list[DataOpsEnvironment]) -> None:
121
+ for env in envs:
122
+ env.close()
123
+
124
+ def __del__(self) -> None:
125
+ try:
126
+ self.close_all()
127
+ except Exception:
128
+ pass
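The eviction rule in `_evict_if_full_locked` (drop the session with the smallest `last_access_at` once the store is full) can be shown with a single-threaded sketch. `StubEnv` and `TinySessionStore` are hypothetical stand-ins introduced only for illustration. The sketch assumes no locking and no timeout sweep, both of which the real manager performs.

```python
from __future__ import annotations

from dataclasses import dataclass


class StubEnv:
    """Hypothetical stand-in for DataOpsEnvironment; it only records close()."""

    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        self.closed = True


@dataclass
class Record:
    env: StubEnv
    last_access_at: float


class TinySessionStore:
    def __init__(self, max_sessions: int) -> None:
        self._sessions: dict[str, Record] = {}
        self._max_sessions = max(1, max_sessions)

    def open(self, session_id: str, now: float) -> StubEnv:
        record = self._sessions.get(session_id)
        if record is None:
            if len(self._sessions) >= self._max_sessions:
                # Evict the least recently accessed session and close its env.
                oldest = min(
                    self._sessions,
                    key=lambda sid: self._sessions[sid].last_access_at,
                )
                self._sessions.pop(oldest).env.close()
            record = Record(env=StubEnv(), last_access_at=now)
            self._sessions[session_id] = record
        else:
            record.last_access_at = now  # a hit refreshes the access time
        return record.env


store = TinySessionStore(max_sessions=2)
env_a = store.open("a", now=1.0)
env_b = store.open("b", now=2.0)
store.open("a", now=3.0)  # refresh "a"; "b" is now the oldest entry
store.open("c", now=4.0)  # store is full, so "b" is evicted and closed
```

The real manager collects evicted environments under the lock and closes them after releasing it, so a slow `close()` does not stall other requests.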
server/task_specs.py ADDED
@@ -0,0 +1,773 @@
1
+ """Seeded task metadata and deterministic scenario builders for DataOpsEnv."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import random
6
+ import re
7
+ import textwrap
8
+ from dataclasses import dataclass
9
+ from datetime import date, timedelta
10
+ from typing import Any, Iterable
11
+
12
+ TASK_IDS = [
13
+ "task_1_easy_anomaly",
14
+ "task_2_medium_syntax",
15
+ "task_3_hard_e2e",
16
+ ]
17
+
18
+
19
+ @dataclass(frozen=True)
20
+ class SQLPolicy:
21
+ allowed_commands: frozenset[str]
22
+ required_table: str
23
+
24
+
25
+ @dataclass(frozen=True)
26
+ class TaskMetadata:
27
+ task_id: str
28
+ name: str
29
+ difficulty: str
30
+ short_description: str
31
+ benchmark_focus: str
32
+ allowed_actions: tuple[str, ...]
33
+
34
+
35
+ @dataclass(frozen=True)
36
+ class Task1Scenario:
37
+ description: str
38
+ all_rows: tuple[dict[str, Any], ...]
39
+ expected_rows: tuple[dict[str, Any], ...]
40
+ corrupted_row_ids: tuple[int, ...]
41
+
42
+
43
+ @dataclass(frozen=True)
44
+ class Task2Scenario:
45
+ description: str
46
+ visible_batch: tuple[dict[str, Any], ...]
47
+ visible_expected: tuple[dict[str, Any], ...]
48
+ hidden_cases: tuple[tuple[dict[str, Any], ...], ...]
49
+ hidden_expected: tuple[tuple[dict[str, Any], ...], ...]
50
+ broken_script: str
51
+
52
+
53
+ @dataclass(frozen=True)
54
+ class Task3Scenario:
55
+ description: str
56
+ target_date: str
57
+ recipient: str
58
+ subject: str
59
+ report_title: str
60
+ all_rows: tuple[dict[str, Any], ...]
61
+ expected_rows: tuple[dict[str, Any], ...]
62
+ broken_script: str
63
+
64
+
65
+ @dataclass(frozen=True)
66
+ class TaskScenarioBundle:
67
+ task_id: str
68
+ seed: int
69
+ description: str
70
+ task_1: Task1Scenario | None = None
71
+ task_2: Task2Scenario | None = None
72
+ task_3: Task3Scenario | None = None
73
+
74
+
75
+ TASK_METADATA = {
76
+ "task_1_easy_anomaly": TaskMetadata(
77
+ task_id="task_1_easy_anomaly",
78
+ name="Delete Corrupted Transaction Rows",
79
+ difficulty="easy",
80
+ short_description=(
81
+ "Inspect a transaction table and remove only the seeded rows with NULL amounts while preserving legitimate non-null edge values."
82
+ ),
83
+ benchmark_focus="Careful data cleanup without collateral damage.",
84
+ allowed_actions=("ExecuteSQL",),
85
+ ),
86
+ "task_2_medium_syntax": TaskMetadata(
87
+ task_id="task_2_medium_syntax",
88
+ name="Repair Seeded Pipeline Script",
89
+ difficulty="medium",
90
+ short_description=(
91
+ "Repair a seeded ETL normalization script and verify it on visible and hidden seeded batches."
92
+ ),
93
+ benchmark_focus="Code reading, precise repair, and generalization beyond the demo batch.",
94
+ allowed_actions=("ReadFile", "WriteFile", "RunScript"),
95
+ ),
96
+ "task_3_hard_e2e": TaskMetadata(
97
+ task_id="task_3_hard_e2e",
98
+ name="Resolve Revenue Reporting Incident",
99
+ difficulty="hard",
100
+ short_description=(
101
+ "Extract a seeded reporting slice, repair the formatter, and send the exact generated report."
102
+ ),
103
+ benchmark_focus="End-to-end data extraction, file repair, and communication with provenance.",
104
+ allowed_actions=("ExecuteSQL", "ReadFile", "WriteFile", "RunScript", "SendEmail"),
105
+ ),
106
+ }
107
+
108
+ TASK_DESCRIPTIONS = {
109
+ task_id: metadata.short_description for task_id, metadata in TASK_METADATA.items()
110
+ }
111
+
112
+ TASK_ALLOWED_WRITE_FILES = {
113
+ "task_1_easy_anomaly": frozenset(),
114
+ "task_2_medium_syntax": frozenset({"broken_pipeline.py"}),
115
+ "task_3_hard_e2e": frozenset({"format_report.py", "report_data.json"}),
116
+ }
117
+
118
+ TASK_ALLOWED_RUN_FILES = {
119
+ "task_1_easy_anomaly": frozenset(),
120
+ "task_2_medium_syntax": frozenset({"broken_pipeline.py"}),
121
+ "task_3_hard_e2e": frozenset({"format_report.py"}),
122
+ }
123
+
124
+ TASK_EMAIL_ENABLED = frozenset({"task_3_hard_e2e"})
125
+
126
+ TASK_ALLOWED_READ_FILES = {
127
+ "task_1_easy_anomaly": frozenset(),
128
+ "task_2_medium_syntax": frozenset({"broken_pipeline.py"}),
129
+ "task_3_hard_e2e": frozenset({"format_report.py", "report_data.json"}),
130
+ }
131
+
132
+ TASK_SQL_POLICIES = {
133
+ "task_1_easy_anomaly": SQLPolicy(
134
+ allowed_commands=frozenset({"SELECT", "DELETE"}),
135
+ required_table="transactions",
136
+ ),
137
+ "task_3_hard_e2e": SQLPolicy(
138
+ allowed_commands=frozenset({"SELECT", "WITH"}),
139
+ required_table="daily_reports",
140
+ ),
141
+ }
142
+
143
+ _REPORT_RECORD_RE = re.compile(
144
+ r"Department:\s*(?P<department>[^\n]+)\n"
145
+ r"\s*Revenue:\s*\$(?P<revenue>-?\d+(?:\.\d+)?)\n"
146
+ r"\s*Expenses:\s*\$(?P<expenses>-?\d+(?:\.\d+)?)\n"
147
+ r"\s*Net:\s*\$(?P<net>-?\d+(?:\.\d+)?)",
148
+ re.MULTILINE,
149
+ )
150
+ _REPORT_TOTAL_RE = re.compile(r"Total Revenue:\s*\$(?P<total>-?\d+(?:\.\d+)?)")
151
+
152
+ _TASK_1_VALID_STATUSES = ("success", "settled", "approved", "completed")
153
+ _TASK_1_CORRUPTED_STATUSES = ("pending", "retrying", "failed", "queued")
154
+ _TASK_2_READY_STATUS = "ready"
155
+ _TASK_2_NON_READY_STATUSES = ("queued", "hold", "failed")
156
+ _TASK_2_REGIONS = ("us-east", "eu-west", "ap-south", "sa-east")
157
+ _TASK_3_RECIPIENTS = (
158
+ "bhavik@example.com",
159
+ "marta@example.com",
160
+ "ops-lead@example.com",
161
+ "finance-review@example.com",
162
+ )
163
+ _TASK_3_DEPARTMENTS = (
164
+ "Engineering",
165
+ "Sales",
166
+ "Marketing",
167
+ "Operations",
168
+ "Support",
169
+ "Finance",
170
+ )
171
+
172
+
173
+ def task_manifest_entries() -> list[dict[str, Any]]:
174
+ return [
175
+ {
176
+ "id": metadata.task_id,
177
+ "name": metadata.name,
178
+ "difficulty": metadata.difficulty,
179
+ "description": metadata.short_description,
180
+ "benchmark_focus": metadata.benchmark_focus,
181
+ "allowed_actions": list(metadata.allowed_actions),
182
+ }
183
+ for metadata in TASK_METADATA.values()
184
+ ]
185
+
186
+
187
+ def build_task_scenario(task_id: str, seed: int | None = None) -> TaskScenarioBundle:
188
+ resolved_seed = 0 if seed is None else int(seed)
189
+
190
+ if task_id == "task_1_easy_anomaly":
191
+ task = _build_task_1_scenario(resolved_seed)
192
+ return TaskScenarioBundle(
193
+ task_id=task_id,
194
+ seed=resolved_seed,
195
+ description=task.description,
196
+ task_1=task,
197
+ )
198
+ if task_id == "task_2_medium_syntax":
199
+ task = _build_task_2_scenario(resolved_seed)
200
+ return TaskScenarioBundle(
201
+ task_id=task_id,
202
+ seed=resolved_seed,
203
+ description=task.description,
204
+ task_2=task,
205
+ )
206
+ if task_id == "task_3_hard_e2e":
207
+ task = _build_task_3_scenario(resolved_seed)
208
+ return TaskScenarioBundle(
209
+ task_id=task_id,
210
+ seed=resolved_seed,
211
+ description=task.description,
212
+ task_3=task,
213
+ )
214
+ raise KeyError(f"Unknown task_id: {task_id}")
215
+
216
+
217
+ def normalize_task_3_rows(
218
+ rows: Iterable[dict[str, Any]], *, require_headcount: bool = False
219
+ ) -> list[dict[str, Any]]:
220
+ """Normalise extracted rows for deterministic comparison."""
221
+ normalised: list[dict[str, Any]] = []
222
+ for row in rows:
223
+ try:
224
+ hc_raw = row.get("headcount")
225
+ if hc_raw is None or hc_raw == "":
226
+ if require_headcount:
227
+ return []
228
+ headcount: int | None = None
229
+ else:
230
+ headcount = int(hc_raw)
231
+ normalised.append(
232
+ {
233
+ "department": str(row["department"]),
234
+ "revenue": round(float(row["revenue"]), 2),
235
+ "expenses": round(float(row["expenses"]), 2),
236
+ "headcount": headcount,
237
+ }
238
+ )
239
+ except (KeyError, TypeError, ValueError):
240
+ return []
241
+ return sorted(normalised, key=lambda item: item["department"])
242
+
243
+
244
+ def normalize_task_2_output_rows(rows: Any) -> list[dict[str, Any]]:
245
+ """Normalise Task 2 ETL output rows while preserving list order for sort checks."""
246
+ if not isinstance(rows, list):
247
+ return []
248
+
249
+ normalised: list[dict[str, Any]] = []
250
+ for row in rows:
251
+ if not isinstance(row, dict):
252
+ return []
253
+ try:
254
+ order_id = str(row["order_id"])
255
+ region = str(row["region"])
256
+ amount_usd = round(float(row["amount_usd"]), 2)
257
+ priority_band = str(row["priority_band"])
258
+ except (KeyError, TypeError, ValueError):
259
+ return []
260
+
261
+ if priority_band not in {"high", "normal"}:
262
+ return []
263
+
264
+ normalised.append(
265
+ {
266
+ "order_id": order_id,
267
+ "region": region,
268
+ "amount_usd": amount_usd,
269
+ "priority_band": priority_band,
270
+ }
271
+ )
272
+
273
+ return normalised
274
+
275
+
276
+ def build_task_2_expected(
277
+ batch: Iterable[dict[str, Any]]
278
+ ) -> list[dict[str, Any]]:
279
+ processed: list[dict[str, Any]] = []
280
+
281
+ for record in batch:
282
+ try:
283
+ status = str(record["status"])
284
+ amount_cents = int(record["amount_cents"])
285
+ priority = int(record["priority"])
286
+ amount_usd = round(amount_cents / 100.0, 2)
287
+ if status != _TASK_2_READY_STATUS or amount_cents <= 0:
288
+ continue
289
+ processed.append(
290
+ {
291
+ "order_id": str(record["order_id"]),
292
+ "region": str(record["region"]),
293
+ "amount_usd": amount_usd,
294
+ "priority_band": "high"
295
+ if priority >= 8 or amount_usd >= 500.0
296
+ else "normal",
297
+ }
298
+ )
299
+ except (KeyError, TypeError, ValueError):
300
+ return []
301
+
302
+ processed.sort(key=lambda item: (-item["amount_usd"], item["order_id"]))
303
+ return processed
304
+
305
+
306
+ def task_3_data_matches_expected(
307
+ rows: list[dict[str, Any]],
308
+ expected_rows: Iterable[dict[str, Any]],
309
+ *,
310
+ require_headcount: bool,
311
+ ) -> bool:
312
+ expected = normalize_task_3_rows(expected_rows, require_headcount=require_headcount)
313
+ return rows == expected
314
+
315
+
316
+ def task_3_headcount_fully_matches(
317
+ rows: list[dict[str, Any]], expected_rows: Iterable[dict[str, Any]]
318
+ ) -> bool:
319
+ expected = normalize_task_3_rows(expected_rows, require_headcount=True)
320
+ return rows == expected
321
+
322
+
323
+ def build_task_3_report(rows: Iterable[dict[str, Any]], target_date: str) -> str:
324
+ report_rows = normalize_task_3_rows(rows, require_headcount=True)
325
+ lines = [f"=== Daily Revenue Report ({target_date}) ===", ""]
326
+ total_revenue = 0.0
327
+
328
+ for row in report_rows:
329
+ revenue = float(row["revenue"])
330
+ expenses = float(row["expenses"])
331
+ net = revenue - expenses
332
+ lines.append(f"Department: {row['department']}")
333
+ lines.append(f" Revenue: ${revenue:.2f}")
334
+ lines.append(f" Expenses: ${expenses:.2f}")
335
+ lines.append(f" Net: ${net:.2f}")
336
+ lines.append("")
337
+ total_revenue += revenue
338
+
339
+ lines.append(f"Total Revenue: ${total_revenue:.2f}")
340
+ lines.append("=== End of Report ===")
341
+ return "\n".join(lines)
342
+
343
+
344
+ def extract_task_3_report_block(text: str, target_date: str) -> str | None:
345
+ raw = text.replace("\r\n", "\n")
346
+ start_marker = f"=== Daily Revenue Report ({target_date}) ==="
347
+ start = raw.find(start_marker)
348
+ end_marker = "=== End of Report ==="
349
+ end = raw.find(end_marker, start) if start != -1 else -1
350
+ if start == -1 or end == -1 or end < start:
351
+ return None
352
+ return raw[start : end + len(end_marker)].strip()
353
+
354
+
355
+ def parse_task_3_report(text: str, target_date: str) -> dict[str, Any] | None:
356
+ block = extract_task_3_report_block(text, target_date)
357
+ if block is None:
358
+ return None
359
+
360
+ records: list[dict[str, Any]] = []
361
+ for match in _REPORT_RECORD_RE.finditer(block):
362
+ revenue = round(float(match.group("revenue")), 2)
363
+ expenses = round(float(match.group("expenses")), 2)
364
+ net = round(float(match.group("net")), 2)
365
+ records.append(
366
+ {
367
+ "department": match.group("department").strip(),
368
+ "revenue": revenue,
369
+ "expenses": expenses,
370
+ "headcount": None,
371
+ "net": net,
372
+ }
373
+ )
374
+
375
+ total_match = _REPORT_TOTAL_RE.search(block)
376
+ if not total_match:
377
+ return None
378
+
379
+ return {
380
+ "records": sorted(records, key=lambda item: item["department"]),
381
+ "total_revenue": round(float(total_match.group("total")), 2),
382
+ }
383
+
384
+
385
+ def report_matches_expected(
386
+ text: str, expected_rows: Iterable[dict[str, Any]], target_date: str
387
+ ) -> bool:
388
+ parsed = parse_task_3_report(text, target_date)
389
+ if parsed is None:
390
+ return False
391
+
392
+ expected = normalize_task_3_rows(expected_rows, require_headcount=True)
393
+ expected_records = [
394
+ {
395
+ "department": row["department"],
396
+ "revenue": row["revenue"],
397
+ "expenses": row["expenses"],
398
+ "headcount": None,
399
+ "net": round(float(row["revenue"]) - float(row["expenses"]), 2),
400
+ }
401
+ for row in expected
402
+ ]
403
+ expected_total = round(sum(float(row["revenue"]) for row in expected), 2)
404
+ return (
405
+ parsed["records"] == expected_records
406
+ and parsed["total_revenue"] == expected_total
407
+ )
408
+
409
+
410
+ def task_3_semantic_match_fraction_rows(
411
+ rows: list[dict[str, Any]], expected_rows: Iterable[dict[str, Any]]
412
+ ) -> float:
413
+ if not rows:
414
+ return 0.0
415
+ expected = normalize_task_3_rows(expected_rows, require_headcount=False)
416
+ exp_by_dept = {row["department"]: row for row in expected}
417
+ matched = 0
418
+ for row in rows:
419
+ department = row.get("department")
420
+ if department not in exp_by_dept:
421
+ continue
422
+ expected_row = exp_by_dept[department]
423
+ if (
424
+ row.get("revenue") == expected_row["revenue"]
425
+ and row.get("expenses") == expected_row["expenses"]
426
+ ):
427
+ matched += 1
428
+ return matched / len(expected) if expected else 0.0
429
+
430
+
431
+ def task_3_semantic_match_fraction_parsed(
432
+ parsed: dict[str, Any] | None, expected_rows: Iterable[dict[str, Any]]
433
+ ) -> float:
434
+ if not parsed or not parsed.get("records"):
435
+ return 0.0
436
+ expected = normalize_task_3_rows(expected_rows, require_headcount=False)
437
+ exp_by_dept = {row["department"]: row for row in expected}
438
+ matched = 0
439
+ for record in parsed["records"]:
440
+ department = record.get("department")
441
+ if department not in exp_by_dept:
442
+ continue
443
+ expected_row = exp_by_dept[department]
444
+ if (
445
+ record.get("revenue") == expected_row["revenue"]
446
+ and record.get("expenses") == expected_row["expenses"]
447
+ ):
448
+ matched += 1
449
+ return matched / len(expected) if expected else 0.0
450
+
451
+
452
+ def task_3_semantic_match_fraction_text(
453
+ text: str, expected_rows: Iterable[dict[str, Any]], target_date: str
454
+ ) -> float:
455
+ return task_3_semantic_match_fraction_parsed(
456
+ parse_task_3_report(text, target_date), expected_rows
457
+ )
458
+
459
+
460
+ def _build_task_1_scenario(seed: int) -> Task1Scenario:
461
+ rng = random.Random(f"task-1:{seed}")
462
+ valid_count = 3 + rng.randrange(3)
463
+ corrupted_count = 2 + rng.randrange(2)
464
+ combined_rows: list[dict[str, Any]] = []
465
+
466
+ valid_templates = []
467
+ for index in range(valid_count):
468
+ valid_templates.append(
469
+ {
470
+ "kind": "valid",
471
+ "user_id": 1000 + seed * 10 + index,
472
+ "amount": round(rng.uniform(75.0, 975.0), 2),
473
+ "status": rng.choice(_TASK_1_VALID_STATUSES),
474
+ }
475
+ )
476
+ if valid_templates:
477
+ valid_templates[0]["amount"] = 0.0
478
+ valid_templates[0]["status"] = "settled"
479
+ if len(valid_templates) > 1:
480
+ valid_templates[1]["amount"] = -round(float(valid_templates[1]["amount"]) / 10.0, 2)
481
+ valid_templates[1]["status"] = "approved"
482
+ corrupted_templates = []
483
+ for index in range(corrupted_count):
484
+ corrupted_templates.append(
485
+ {
486
+ "kind": "corrupted",
487
+ "user_id": 2000 + seed * 10 + index,
488
+ "amount": None,
489
+ "status": rng.choice(_TASK_1_CORRUPTED_STATUSES),
490
+ }
491
+ )
492
+
493
+ templates = valid_templates + corrupted_templates
494
+ rng.shuffle(templates)
495
+
496
+ expected_rows: list[dict[str, Any]] = []
497
+ corrupted_row_ids: list[int] = []
498
+ for row_id, template in enumerate(templates, start=1):
499
+ row = {
500
+ "id": row_id,
501
+ "user_id": int(template["user_id"]),
502
+ "amount": template["amount"],
503
+ "status": str(template["status"]),
504
+ }
505
+ combined_rows.append(row)
506
+ if template["kind"] == "valid":
507
+ expected_rows.append(row)
508
+ else:
509
+ corrupted_row_ids.append(row_id)
510
+
511
+ description = (
512
+ "Find and delete all corrupted records (rows with NULL amounts) from the "
513
+ f"'transactions' table. This seeded episode contains {corrupted_count} corrupted "
514
+ f"rows mixed with {valid_count} valid rows. Only NULL amounts are corrupted; "
515
+ "legitimate zero-value reconciliations and negative refund adjustments may also "
516
+ "appear and must be preserved exactly."
517
+ )
518
+ return Task1Scenario(
519
+ description=description,
520
+ all_rows=tuple(combined_rows),
521
+ expected_rows=tuple(expected_rows),
522
+ corrupted_row_ids=tuple(sorted(corrupted_row_ids)),
523
+ )
524
+
525
+
526
+ def _build_task_2_scenario(seed: int) -> Task2Scenario:
527
+ rng = random.Random(f"task-2:{seed}")
528
+ visible_batch = _sample_task_2_batch(rng, batch_index=0)
529
+ hidden_cases = tuple(
530
+ _sample_task_2_batch(rng, batch_index=index + 1)
531
+ for index in range(6)
532
+ )
533
+ visible_expected = tuple(build_task_2_expected(visible_batch))
534
+ hidden_expected = tuple(
535
+ tuple(build_task_2_expected(batch)) for batch in hidden_cases
536
+ )
537
+ description = (
538
+ "The script 'broken_pipeline.py' prepares downstream billing candidates from "
539
+ "seeded order records. Repair it so it keeps only ready records with positive "
540
+ "amounts, converts cents to USD, flags high priority when priority >= 8 or "
541
+ "amount_usd >= 500.00, and returns rows sorted by amount_usd descending then "
542
+ "order_id ascending. The grader checks the visible demo batch and additional "
543
+ "unseen seeded batches."
544
+ )
545
+ return Task2Scenario(
546
+ description=description,
547
+ visible_batch=visible_batch,
548
+ visible_expected=visible_expected,
549
+ hidden_cases=hidden_cases,
550
+ hidden_expected=hidden_expected,
551
+ broken_script=_render_broken_pipeline_script(visible_batch),
552
+ )
553
+
554
+
555
+ def _build_task_3_scenario(seed: int) -> Task3Scenario:
556
+ rng = random.Random(f"task-3:{seed}")
557
+ base_date = date(2025, 3, 25) + timedelta(days=rng.randrange(0, 7))
558
+ target_date = base_date.isoformat()
559
+ recipient = rng.choice(_TASK_3_RECIPIENTS)
560
+ subject = f"Daily Revenue Report - {target_date}"
561
+ report_title = f"Daily Revenue Report ({target_date})"
562
+ selected_departments = sorted(rng.sample(_TASK_3_DEPARTMENTS, k=4))
563
+
564
+ expected_rows: list[dict[str, Any]] = []
565
+ warehouse_rows: list[dict[str, Any]] = []
566
+ row_id = 1
567
+ for offset in (-2, -1, 0, 1):
568
+ report_date = (base_date + timedelta(days=offset)).isoformat()
569
+ for department in selected_departments:
570
+ if offset == 0:
571
+ revenue = round(rng.uniform(12_000.0, 95_000.0), 2)
572
+ expenses = round(rng.uniform(8_000.0, revenue + 18_000.0), 2)
573
+ headcount = rng.randint(8, 48)
574
+ seeded_row = {
575
+ "department": department,
576
+ "revenue": revenue,
577
+ "expenses": expenses,
578
+ "headcount": headcount,
579
+ }
580
+ expected_rows.append(seeded_row)
581
+ else:
582
+ revenue = round(rng.uniform(9_000.0, 90_000.0), 2)
583
+ expenses = round(rng.uniform(7_000.0, revenue + 14_000.0), 2)
584
+ headcount = rng.randint(8, 48)
585
+ warehouse_rows.append(
586
+ {
587
+ "id": row_id,
588
+ "report_date": report_date,
589
+ "department": department,
590
+ "revenue": revenue,
591
+ "expenses": expenses,
592
+ "headcount": headcount,
593
+ }
594
+ )
595
+ row_id += 1
596
+
597
+ description = (
598
+ f"Extract the daily report for date '{target_date}' from the 'daily_reports' table, "
599
+ "repair the broken 'format_report.py' script, save the exact extracted rows to "
600
+ f"'report_data.json', run the script with that file, and send the generated report "
601
+ f"to '{recipient}' with subject '{subject}'. The grader expects the exact seeded slice, "
602
+ "including headcount."
603
+ )
604
+ return Task3Scenario(
605
+ description=description,
606
+ target_date=target_date,
607
+ recipient=recipient,
608
+ subject=subject,
609
+ report_title=report_title,
610
+ all_rows=tuple(warehouse_rows),
611
+ expected_rows=tuple(
612
+ normalize_task_3_rows(expected_rows, require_headcount=True)
613
+ ),
614
+ broken_script=_render_broken_format_report_script(target_date),
615
+ )
616
+
617
+
618
+ def _sample_task_2_batch(
619
+ rng: random.Random, *, batch_index: int
620
+ ) -> tuple[dict[str, Any], ...]:
621
+ def make_record(
622
+ suffix: str,
623
+ *,
624
+ status: str,
625
+ amount_cents: int,
626
+ priority: int,
627
+ ) -> dict[str, Any]:
628
+ return {
629
+ "order_id": f"ORD-{batch_index:02d}-{suffix}",
630
+ "status": status,
631
+ "amount_cents": amount_cents,
632
+ "priority": priority,
633
+ "region": rng.choice(_TASK_2_REGIONS),
634
+ }
635
+
636
+ records = [
637
+ make_record(
638
+ "normal",
639
+ status=_TASK_2_READY_STATUS,
640
+ amount_cents=rng.randrange(12_125, 28_975, 25),
641
+ priority=rng.randint(2, 6),
642
+ ),
643
+ make_record(
644
+ "priority",
645
+ status=_TASK_2_READY_STATUS,
646
+ amount_cents=rng.randrange(13_175, 32_775, 25),
647
+ priority=rng.randint(8, 10),
648
+ ),
649
+ make_record(
650
+ "amount",
651
+ status=_TASK_2_READY_STATUS,
652
+ amount_cents=rng.randrange(50_025, 88_975, 25),
653
+ priority=rng.randint(2, 6),
654
+ ),
655
+ make_record(
656
+ "queued",
657
+ status=rng.choice(_TASK_2_NON_READY_STATUSES[:2]),
658
+ amount_cents=rng.randrange(18_125, 42_975, 25),
659
+ priority=rng.randint(4, 9),
660
+ ),
661
+ make_record(
662
+ "drop",
663
+ status=_TASK_2_READY_STATUS,
664
+ amount_cents=-rng.randrange(125, 2_975, 25),
665
+ priority=rng.randint(8, 10),
666
+ ),
667
+ ]
668
+
669
+ if batch_index % 2 == 0:
670
+ records.append(
671
+ make_record(
672
+ "hold",
673
+ status=rng.choice(_TASK_2_NON_READY_STATUSES),
674
+ amount_cents=rng.randrange(24_125, 48_975, 25),
675
+ priority=rng.randint(1, 7),
676
+ )
677
+ )
678
+
679
+ rng.shuffle(records)
680
+ return tuple(records)
681
+
682
+
683
+ def _render_broken_pipeline_script(
684
+ visible_batch: tuple[dict[str, Any], ...]
685
+ ) -> str:
686
+ return textwrap.dedent(
687
+ f'''\
688
+ import json
689
+
690
+
691
+ def process_data_stream(payloads):
692
+ """
693
+ Normalize downstream billing candidates.
694
+
695
+ Keep only records whose status is "ready" and whose amount_cents is positive.
696
+ Convert amount_cents to amount_usd rounded to 2 decimals.
697
+ Mark priority_band as "high" when priority >= 8 or amount_usd >= 500.00.
698
+ Return rows sorted by amount_usd descending, then order_id ascending.
699
+ """
700
+ processed_records = []
701
+
702
+ for payload in payloads:
703
+ if payload["status"] == "failed" or payload["amount_cents"] <= 0:
704
+ continue
705
+
706
+ amount_usd = round(payload["amount_cents"] // 100, 2)
707
+ priority_band = (
708
+ "high"
709
+ if payload["priority"] >= 8 and amount_usd >= 500.0
710
+ else "normal"
711
+ )
712
+ processed_records.append(
713
+ {{
714
+ "order_id": payload["order_id"],
715
+ "region": payload["region"],
716
+ "amount_usd": amount_usd,
717
+ "priority_band": priority_band,
718
+ }}
719
+ )
720
+
721
+ processed_records.sort(key=lambda item: (item["amount_usd"], item["order_id"]))
722
+ return processed_records
723
+
724
+ if __name__ == "__main__":
725
+ mock_batch = {list(visible_batch)!r}
726
+ print(json.dumps(process_data_stream(mock_batch), indent=2, sort_keys=True))
727
+ '''
728
+ ).lstrip()
729
+
730
+
731
+ def _render_broken_format_report_script(target_date: str) -> str:
732
+ title = f"=== Daily Revenue Report ({target_date}) ==="
733
+ return textwrap.dedent(
734
+ f'''\
735
+ import json
736
+ import sys
737
+
738
+
739
+ def format_report(input_path):
740
+ """Reads extracted data from JSON and produces a formatted stakeholder report."""
741
+ with open(input_path, encoding="utf-8") as f:
742
+ records = json.load(f)
743
+
744
+ lines = ["{title}", ""]
745
+ total_revenue = 0
746
+
747
+ for rec in records:
748
+ dept = rec["department"]
749
+ rev = int(rec["revenue"]) # BUG 1: int() truncates decimal precision
750
+ exp = rec["expenses"]
751
+ net = rev - exp
752
+ lines.append(f"Department: {{dept}}")
753
+ lines.append(f" Revenue: ${{rev}}")
754
+ lines.append(f" Expenses: ${{exp:.2f}}")
755
+ lines.append(f" Net: ${{net:.2f}}")
756
+ lines.append("")
757
+ total_revenue += rev
758
+
759
+ lines.append(f"Total Revenue: ${{total_revenue}}")
760
+ lines.append("=== End of Report ===")
761
+
762
+ output = "\\n".join(lines)
763
+ print(output)
764
+ return output
765
+
766
+
767
+ if __name__ == "__main__":
768
+ if len(sys.argv) < 2:
769
+ print("Usage: python format_report.py <input.json>", file=sys.stderr)
770
+ sys.exit(1)
771
+ format_report(sys.argv[0]) # BUG 2: should be sys.argv[1]
772
+ '''
773
+ ).lstrip()
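
Reviewer note: the docstring in the pipeline template above states four rules (keep only "ready" rows with positive `amount_cents`, convert cents to dollars rounded to 2 decimals, mark the band "high" when either threshold is met, sort by amount descending then `order_id`). The sketch below is one way to satisfy those rules, written to match the docstring rather than taken from the project's grader. The sample records and their values are hypothetical and only illustrate the filter, the band logic, and the tie-break.

```python
def process_data_stream(payloads):
    """Apply the docstring's rules: filter, convert, band, then sort."""
    rows = []
    for p in payloads:
        # Keep only "ready" records with a positive amount.
        if p["status"] != "ready" or p["amount_cents"] <= 0:
            continue
        # True division keeps the cents; floor division would drop them.
        amount_usd = round(p["amount_cents"] / 100, 2)
        # Either condition alone is enough for the "high" band.
        band = "high" if p["priority"] >= 8 or amount_usd >= 500.0 else "normal"
        rows.append(
            {
                "order_id": p["order_id"],
                "region": p["region"],
                "amount_usd": amount_usd,
                "priority_band": band,
            }
        )
    # Negating the amount sorts it descending while order_id stays ascending.
    rows.sort(key=lambda r: (-r["amount_usd"], r["order_id"]))
    return rows


# Hypothetical sample batch (not taken from the seeded scenarios).
sample = [
    {"order_id": "A-2", "region": "eu", "status": "ready", "amount_cents": 50000, "priority": 1},
    {"order_id": "A-1", "region": "us", "status": "ready", "amount_cents": 50000, "priority": 1},
    {"order_id": "B-1", "region": "us", "status": "pending", "amount_cents": 999, "priority": 9},
    {"order_id": "C-1", "region": "us", "status": "ready", "amount_cents": 1999, "priority": 3},
    {"order_id": "D-1", "region": "us", "status": "ready", "amount_cents": 0, "priority": 9},
]
result = process_data_stream(sample)
```

In this sample the "pending" and zero-amount records are dropped, the two 500.00 rows tie on amount and fall back to `order_id` order, and the 19.99 row lands in the "normal" band.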
tests/test_grading.py ADDED
@@ -0,0 +1,577 @@
1
+ """High-signal regression tests for seeded grading and public API shape."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import shutil
7
+ import sqlite3
8
+ import sys
9
+ from pathlib import Path
10
+
11
+ from fastapi.testclient import TestClient
12
+
13
+ _ROOT = Path(__file__).resolve().parents[1]
14
+ if str(_ROOT) not in sys.path:
15
+ sys.path.insert(0, str(_ROOT))
16
+
17
+ from models import DataOpsAction # noqa: E402
18
+ from server.app import app # noqa: E402
19
+ from server.dataops_env_environment import DataOpsEnvironment # noqa: E402
20
+ from server.grading import evaluate_task # noqa: E402
21
+ from server.task_specs import build_task_3_report # noqa: E402
22
+
23
+
24
+ def _fixed_pipeline_script(visible_batch: list[dict[str, object]]) -> str:
25
+ return f'''\
26
+ import json
27
+
28
+
29
+ def process_data_stream(payloads):
30
+ processed_records = []
31
+ for payload in payloads:
32
+ if payload["status"] != "ready" or int(payload["amount_cents"]) <= 0:
33
+ continue
34
+ amount_usd = round(int(payload["amount_cents"]) / 100.0, 2)
35
+ priority_band = (
36
+ "high"
37
+ if int(payload["priority"]) >= 8 or amount_usd >= 500.0
38
+ else "normal"
39
+ )
40
+ processed_records.append(
41
+ {{
42
+ "order_id": payload["order_id"],
43
+ "region": payload["region"],
44
+ "amount_usd": amount_usd,
45
+ "priority_band": priority_band,
46
+ }}
47
+ )
48
+ processed_records.sort(key=lambda item: (-item["amount_usd"], item["order_id"]))
49
+ return processed_records
50
+
51
+
52
+ if __name__ == "__main__":
53
+ mock_batch = {visible_batch!r}
54
+ print(json.dumps(process_data_stream(mock_batch), indent=2, sort_keys=True))
55
+ '''
56
+
57
+
58
+ def _visible_only_pipeline_stub(
59
+ visible_batch: list[dict[str, object]],
60
+ visible_expected: list[dict[str, object]],
61
+ ) -> str:
62
+ return f'''\
63
+ import json
64
+
65
+
66
+ def process_data_stream(payloads):
67
+ visible = {visible_batch!r}
68
+ if payloads == visible:
69
+ return {visible_expected!r}
70
+ return []
71
+
72
+
73
+ if __name__ == "__main__":
74
+ print(json.dumps({visible_expected!r}, indent=2, sort_keys=True))
75
+ '''
76
+
77
+
78
+ def _fixed_format_script(target_date: str) -> str:
79
+ return f'''\
80
+ import json
81
+ import sys
82
+
83
+
84
+ def format_report(input_path):
85
+ with open(input_path, encoding="utf-8") as f:
86
+ records = json.load(f)
87
+ lines = ["=== Daily Revenue Report ({target_date}) ===", ""]
88
+ total_revenue = 0.0
89
+ for rec in records:
90
+ dept = rec["department"]
91
+ rev = float(rec["revenue"])
92
+ exp = float(rec["expenses"])
93
+ net = rev - exp
94
+ lines.append(f"Department: {{dept}}")
95
+ lines.append(f" Revenue: ${{rev:.2f}}")
96
+ lines.append(f" Expenses: ${{exp:.2f}}")
97
+ lines.append(f" Net: ${{net:.2f}}")
98
+ lines.append("")
99
+ total_revenue += rev
100
+ lines.append(f"Total Revenue: ${{total_revenue:.2f}}")
101
+ lines.append("=== End of Report ===")
102
+ out = "\\n".join(lines)
103
+ print(out)
104
+ return out
105
+
106
+
107
+ if __name__ == "__main__":
108
+ if len(sys.argv) < 2:
109
+ print("Usage: python format_report.py <input.json>", file=sys.stderr)
110
+ sys.exit(1)
111
+ format_report(sys.argv[1])
112
+ '''
113
+
114
+
115
+ def test_seeded_task_3_scenario_is_deterministic() -> None:
116
+ env_a = DataOpsEnvironment()
117
+ env_b = DataOpsEnvironment()
118
+ try:
119
+ env_a.reset(task_id="task_3_hard_e2e", seed=17)
120
+ env_b.reset(task_id="task_3_hard_e2e", seed=17)
121
+ assert env_a.scenario.task_3 == env_b.scenario.task_3
122
+ finally:
123
+ env_a.close()
124
+ env_b.close()
125
+
126
+
127
+ def test_task_1_perfect_score_seeded() -> None:
128
+ env = DataOpsEnvironment()
129
+ env.reset(task_id="task_1_easy_anomaly", seed=7)
130
+ try:
131
+ obs = env.step(
132
+ DataOpsAction(
133
+ action_type="ExecuteSQL",
134
+ payload={"query": "DELETE FROM transactions WHERE amount IS NULL"},
135
+ )
136
+ )
137
+ assert obs.status == "success"
138
+ out = evaluate_task("task_1_easy_anomaly", env)
139
+ assert out["score"] == 1.0
140
+ finally:
141
+ env.close()
142
+ shutil.rmtree(env.workspace_dir, ignore_errors=True)
143
+
144
+
145
+ def test_task_1_seeded_valid_rows_include_non_null_edge_amounts() -> None:
146
+ env = DataOpsEnvironment()
147
+ env.reset(task_id="task_1_easy_anomaly", seed=7)
148
+ try:
149
+ scenario = env.scenario.task_1
150
+ assert scenario is not None
151
+ amounts = [float(row["amount"]) for row in scenario.expected_rows]
152
+ assert any(amount == 0.0 for amount in amounts)
153
+ assert any(amount < 0.0 for amount in amounts)
154
+ finally:
155
+ env.close()
156
+ shutil.rmtree(env.workspace_dir, ignore_errors=True)
157
+
158
+
159
+ def test_task_1_rewriting_corrupted_rows_scores_zero() -> None:
160
+ env = DataOpsEnvironment()
161
+ env.reset(task_id="task_1_easy_anomaly", seed=7)
162
+ try:
163
+ with sqlite3.connect(env.db_path) as conn:
164
+ conn.execute("UPDATE transactions SET amount = 0 WHERE amount IS NULL")
165
+ conn.commit()
166
+ out = evaluate_task("task_1_easy_anomaly", env)
167
+ assert out["score"] == 0.0
168
+ finally:
169
+ env.close()
170
+ shutil.rmtree(env.workspace_dir, ignore_errors=True)
171
+
172
+
173
+ def test_task_1_deleting_non_null_adjustments_is_penalized() -> None:
174
+ env = DataOpsEnvironment()
175
+ env.reset(task_id="task_1_easy_anomaly", seed=7)
176
+ try:
177
+ obs = env.step(
178
+ DataOpsAction(
179
+ action_type="ExecuteSQL",
180
+ payload={
181
+ "query": "DELETE FROM transactions WHERE amount IS NULL OR amount <= 0"
182
+ },
183
+ )
184
+ )
185
+ assert obs.status == "success"
186
+ assert obs.reward is not None and obs.reward < 0
187
+ out = evaluate_task("task_1_easy_anomaly", env)
188
+ assert out["score"] == 0.0
189
+ finally:
190
+ env.close()
191
+ shutil.rmtree(env.workspace_dir, ignore_errors=True)
192
+
193
+
194
+ def test_reset_only_scores_zero_across_tasks() -> None:
195
+ for task_id in (
196
+ "task_1_easy_anomaly",
197
+ "task_2_medium_syntax",
198
+ "task_3_hard_e2e",
199
+ ):
200
+ env = DataOpsEnvironment()
201
+ try:
202
+ env.reset(task_id=task_id, seed=7)
203
+ out = evaluate_task(task_id, env)
204
+ assert out["score"] == 0.0
205
+ finally:
206
+ env.close()
207
+ shutil.rmtree(env.workspace_dir, ignore_errors=True)
208
+
209
+
210
+ def test_task_1_broad_delete_with_where_is_penalized() -> None:
211
+ env = DataOpsEnvironment()
212
+ env.reset(task_id="task_1_easy_anomaly", seed=7)
213
+ try:
214
+ obs = env.step(
215
+ DataOpsAction(
216
+ action_type="ExecuteSQL",
217
+ payload={
218
+ "query": "DELETE FROM transactions WHERE amount IS NULL OR 1 = 1"
219
+ },
220
+ )
221
+ )
222
+ assert obs.status == "success"
223
+ assert obs.reward is not None and obs.reward < 0
224
+ assert env.evidence["task_1"]["destructive_sql_attempted"] is True
225
+ out = evaluate_task("task_1_easy_anomaly", env)
226
+ assert out["score"] == 0.0
227
+ finally:
228
+ env.close()
229
+ shutil.rmtree(env.workspace_dir, ignore_errors=True)
230
+
231
+
232
+ def test_task_2_script_run_does_not_inherit_server_secrets(monkeypatch) -> None:
233
+ monkeypatch.setenv("API_KEY", "super-secret-value")
234
+ env = DataOpsEnvironment()
235
+ env.reset(task_id="task_2_medium_syntax", seed=11)
236
+ script = """\
237
+ import json
238
+ import os
239
+
240
+
241
+ def process_data_stream(payloads):
242
+ return []
243
+
244
+
245
+ if __name__ == "__main__":
246
+ print(json.dumps({"api_key": os.getenv("API_KEY"), "home": os.getenv("HOME")}))
247
+ """
248
+ try:
249
+ env.step(
250
+ DataOpsAction(
251
+ action_type="WriteFile",
252
+ payload={"filepath": "broken_pipeline.py", "content": script},
253
+ )
254
+ )
255
+ run_obs = env.step(
256
+ DataOpsAction(
257
+ action_type="RunScript",
258
+ payload={"filepath": "broken_pipeline.py", "args": []},
259
+ )
260
+ )
261
+ assert run_obs.status == "success"
262
+ payload = json.loads((run_obs.stdout or "").strip())
263
+ assert payload["api_key"] is None
264
+ assert payload["home"] == env.workspace_dir
265
+ finally:
266
+ env.close()
267
+ shutil.rmtree(env.workspace_dir, ignore_errors=True)
268
+
269
+
270
+ def test_task_2_perfect_score_seeded() -> None:
271
+ env = DataOpsEnvironment()
272
+ env.reset(task_id="task_2_medium_syntax", seed=11)
273
+ scenario = env.scenario.task_2
274
+ assert scenario is not None
275
+ try:
276
+ read_obs = env.step(
277
+ DataOpsAction(
278
+ action_type="ReadFile",
279
+ payload={"filepath": "broken_pipeline.py"},
280
+ )
281
+ )
282
+ assert read_obs.status == "success"
283
+ write_obs = env.step(
284
+ DataOpsAction(
285
+ action_type="WriteFile",
286
+ payload={
287
+ "filepath": "broken_pipeline.py",
288
+ "content": _fixed_pipeline_script(list(scenario.visible_batch)),
289
+ },
290
+ )
291
+ )
292
+ assert write_obs.status == "success"
293
+ pre_run = evaluate_task("task_2_medium_syntax", env)
294
+ assert 0.0 < pre_run["score"] < 1.0
295
+ run_obs = env.step(
296
+ DataOpsAction(
297
+ action_type="RunScript",
298
+ payload={"filepath": "broken_pipeline.py", "args": []},
299
+ )
300
+ )
301
+ assert run_obs.status == "success"
302
+ out = evaluate_task("task_2_medium_syntax", env)
303
+ assert out["score"] == 1.0
304
+ finally:
305
+ env.close()
306
+ shutil.rmtree(env.workspace_dir, ignore_errors=True)
307
+
308
+
309
+ def test_task_2_print_only_stub_does_not_get_full_credit() -> None:
310
+ env = DataOpsEnvironment()
311
+ env.reset(task_id="task_2_medium_syntax", seed=11)
312
+ scenario = env.scenario.task_2
313
+ assert scenario is not None
314
+ stub = _visible_only_pipeline_stub(
315
+ list(scenario.visible_batch),
316
+ list(scenario.visible_expected),
317
+ )
318
+ try:
319
+ env.step(
320
+ DataOpsAction(
321
+ action_type="WriteFile",
322
+ payload={"filepath": "broken_pipeline.py", "content": stub},
323
+ )
324
+ )
325
+ out = evaluate_task("task_2_medium_syntax", env)
326
+ assert out["score"] < 0.5
327
+ finally:
328
+ env.close()
329
+ shutil.rmtree(env.workspace_dir, ignore_errors=True)
330
+
331
+
332
+ def test_task_3_sql_policy_rejects_literal_table_name_bypass() -> None:
333
+ env = DataOpsEnvironment()
334
+ env.reset(task_id="task_3_hard_e2e", seed=19)
335
+ try:
336
+ obs = env.step(
337
+ DataOpsAction(
338
+ action_type="ExecuteSQL",
339
+ payload={
340
+ "query": (
341
+ "SELECT name FROM sqlite_master "
342
+ "WHERE 'daily_reports' = 'daily_reports'"
343
+ )
344
+ },
345
+ )
346
+ )
347
+ assert obs.status == "error"
348
+ assert "disallowed" in obs.message.lower()
349
+ finally:
350
+ env.close()
351
+ shutil.rmtree(env.workspace_dir, ignore_errors=True)
352
+
353
+
354
+ def test_task_3_sql_policy_allows_cte_queries_over_daily_reports() -> None:
355
+ env = DataOpsEnvironment()
356
+ env.reset(task_id="task_3_hard_e2e", seed=19)
357
+ scenario = env.scenario.task_3
358
+ assert scenario is not None
359
+ try:
360
+ obs = env.step(
361
+ DataOpsAction(
362
+ action_type="ExecuteSQL",
363
+ payload={
364
+ "query": (
365
+ "WITH scoped AS ("
366
+ "SELECT department, revenue, expenses, headcount "
367
+ "FROM daily_reports "
368
+ f"WHERE report_date = '{scenario.target_date}'"
369
+ ") "
370
+ "SELECT department, revenue, expenses, headcount "
371
+ "FROM scoped ORDER BY department"
372
+ )
373
+ },
374
+ )
375
+ )
376
+ assert obs.status == "success"
377
+ assert obs.sql_results
378
+ finally:
379
+ env.close()
380
+ shutil.rmtree(env.workspace_dir, ignore_errors=True)
381
+
382
+
383
+ def test_task_3_perfect_score_requires_proven_workflow() -> None:
384
+ env = DataOpsEnvironment()
385
+ env.reset(task_id="task_3_hard_e2e", seed=19)
386
+ scenario = env.scenario.task_3
387
+ assert scenario is not None
388
+ try:
389
+ query = (
390
+ "SELECT department, revenue, expenses, headcount "
391
+ "FROM daily_reports "
392
+ f"WHERE report_date = '{scenario.target_date}' "
393
+ "ORDER BY department"
394
+ )
395
+ sql_obs = env.step(
396
+ DataOpsAction(
397
+ action_type="ExecuteSQL",
398
+ payload={"query": query},
399
+ )
400
+ )
401
+ assert sql_obs.status == "success"
402
+ rows = sql_obs.sql_results
403
+ assert rows is not None
404
+
405
+ write_json = env.step(
406
+ DataOpsAction(
407
+ action_type="WriteFile",
408
+ payload={"filepath": "report_data.json", "content": json.dumps(rows)},
409
+ )
410
+ )
411
+ assert write_json.status == "success"
412
+
413
+ write_script = env.step(
414
+ DataOpsAction(
415
+ action_type="WriteFile",
416
+ payload={
417
+ "filepath": "format_report.py",
418
+ "content": _fixed_format_script(scenario.target_date),
419
+ },
420
+ )
421
+ )
422
+ assert write_script.status == "success"
423
+
424
+ run_obs = env.step(
425
+ DataOpsAction(
426
+ action_type="RunScript",
427
+ payload={"filepath": "format_report.py", "args": ["report_data.json"]},
428
+ )
429
+ )
430
+ assert run_obs.status == "success"
431
+ body = (run_obs.stdout or "").strip()
432
+
433
+ email_obs = env.step(
434
+ DataOpsAction(
435
+ action_type="SendEmail",
436
+ payload={
437
+ "to_email": scenario.recipient,
438
+ "subject": scenario.subject,
439
+ "body": body,
440
+ },
441
+ )
442
+ )
443
+ assert email_obs.status == "success"
444
+
445
+ out = evaluate_task("task_3_hard_e2e", env)
446
+ assert out["score"] == 1.0
447
+ finally:
448
+ env.close()
449
+ shutil.rmtree(env.workspace_dir, ignore_errors=True)
450
+
451
+
452
+ def test_task_3_equivalent_relative_input_path_still_scores_perfect() -> None:
453
+ env = DataOpsEnvironment()
454
+ env.reset(task_id="task_3_hard_e2e", seed=29)
455
+ scenario = env.scenario.task_3
456
+ assert scenario is not None
457
+ try:
458
+ query = (
459
+ "SELECT department, revenue, expenses, headcount "
460
+ "FROM daily_reports "
461
+ f"WHERE report_date = '{scenario.target_date}' "
462
+ "ORDER BY department"
463
+ )
464
+ sql_obs = env.step(
465
+ DataOpsAction(
466
+ action_type="ExecuteSQL",
467
+ payload={"query": query},
468
+ )
469
+ )
470
+ assert sql_obs.status == "success"
471
+ rows = sql_obs.sql_results
472
+ assert rows is not None
473
+
474
+ env.step(
475
+ DataOpsAction(
476
+ action_type="WriteFile",
477
+ payload={"filepath": "report_data.json", "content": json.dumps(rows)},
478
+ )
479
+ )
480
+ env.step(
481
+ DataOpsAction(
482
+ action_type="WriteFile",
483
+ payload={
484
+ "filepath": "format_report.py",
485
+ "content": _fixed_format_script(scenario.target_date),
486
+ },
487
+ )
488
+ )
489
+ run_obs = env.step(
490
+ DataOpsAction(
491
+ action_type="RunScript",
492
+ payload={"filepath": "format_report.py", "args": ["./report_data.json"]},
493
+ )
494
+ )
495
+ assert run_obs.status == "success"
496
+
497
+ env.step(
498
+ DataOpsAction(
499
+ action_type="SendEmail",
500
+ payload={
501
+ "to_email": scenario.recipient,
502
+ "subject": scenario.subject,
503
+ "body": (run_obs.stdout or "").strip(),
504
+ },
505
+ )
506
+ )
507
+
508
+ out = evaluate_task("task_3_hard_e2e", env)
509
+ assert out["score"] == 1.0
510
+ finally:
511
+ env.close()
512
+ shutil.rmtree(env.workspace_dir, ignore_errors=True)
513
+
514
+
515
+ def test_task_3_fabricated_email_only_scores_low() -> None:
516
+ env = DataOpsEnvironment()
517
+ env.reset(task_id="task_3_hard_e2e", seed=23)
518
+ scenario = env.scenario.task_3
519
+ assert scenario is not None
520
+ try:
521
+ fake_body = build_task_3_report(list(scenario.expected_rows), scenario.target_date)
522
+ email_obs = env.step(
523
+ DataOpsAction(
524
+ action_type="SendEmail",
525
+ payload={
526
+ "to_email": scenario.recipient,
527
+ "subject": scenario.subject,
528
+ "body": fake_body,
529
+ },
530
+ )
531
+ )
532
+ assert email_obs.status == "success"
533
+ out = evaluate_task("task_3_hard_e2e", env)
534
+ assert out["score"] <= 0.10
535
+ finally:
536
+ env.close()
537
+ shutil.rmtree(env.workspace_dir, ignore_errors=True)
538
+
539
+
540
+ def test_task_3_reading_formatter_source_awards_progress_signal() -> None:
541
+ env = DataOpsEnvironment()
542
+ env.reset(task_id="task_3_hard_e2e", seed=31)
543
+ try:
544
+ obs = env.step(
545
+ DataOpsAction(
546
+ action_type="ReadFile",
547
+ payload={"filepath": "format_report.py"},
548
+ )
549
+ )
550
+ assert obs.status == "success"
551
+ assert obs.reward is not None and obs.reward > 0
552
+ finally:
553
+ env.close()
554
+ shutil.rmtree(env.workspace_dir, ignore_errors=True)
555
+
556
+
557
+ def test_tasks_endpoint_exposes_manifest_metadata() -> None:
558
+ with TestClient(app) as client:
559
+ response = client.get("/tasks")
560
+ payload = response.json()
561
+ assert response.status_code == 200
562
+ assert len(payload["tasks"]) == 3
563
+ assert payload["tasks"][0]["difficulty"] == "easy"
564
+ assert "action_schema" in payload
565
+
566
+
567
+ def test_public_grader_hides_details_by_default(monkeypatch) -> None:
568
+ # Do not leak grader details when PUBLIC_GRADER_DETAILS is unset/false (ignore dev .env).
569
+ monkeypatch.setenv("PUBLIC_GRADER_DETAILS", "false")
570
+ with TestClient(app) as client:
571
+ reset = client.post("/reset?task_id=task_1_easy_anomaly", json={"seed": 5})
572
+ assert reset.status_code == 200
573
+ grade = client.get("/grader")
574
+ assert grade.status_code == 200
575
+ payload = grade.json()
576
+ assert "score" in payload
577
+ assert "details" not in payload
tests/test_inference_api.py ADDED
@@ -0,0 +1,408 @@
1
+ """Integration-style checks for inference wiring and transport/session flows."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ import sys
8
+ from pathlib import Path
9
+ from subprocess import CompletedProcess
10
+ from types import SimpleNamespace
11
+ from typing import Any
12
+
13
+ from fastapi.testclient import TestClient
14
+
15
+ _ROOT = Path(__file__).resolve().parents[1]
16
+ if str(_ROOT) not in sys.path:
17
+ sys.path.insert(0, str(_ROOT))
18
+
19
+ import inference # noqa: E402
20
+ import env_loader # noqa: E402
21
+ import server.app as app_module # noqa: E402
22
+ from client import DataOpsEnvClient # noqa: E402
23
+ from models import DataOpsAction # noqa: E402
24
+
25
+
26
+ class _FakeResponse:
27
+ def __init__(self, payload: dict[str, Any]) -> None:
28
+ self._payload = payload
29
+
30
+ def raise_for_status(self) -> None:
31
+ return None
32
+
33
+ def json(self) -> dict[str, Any]:
34
+ return self._payload
35
+
36
+
37
+ class _FakeHTTPSession:
38
+ def __init__(self) -> None:
39
+ self.urls: list[str] = []
40
+ self._step_count = 0
41
+
42
+ def request(
43
+ self,
44
+ method: str,
45
+ url: str,
46
+ timeout: float | None = None,
47
+ **kwargs: Any,
48
+ ) -> _FakeResponse:
49
+ del method, timeout, kwargs
50
+ self.urls.append(url)
51
+ if url.endswith("/reset"):
52
+ self._step_count = 0
53
+ return _FakeResponse(
54
+ {
55
+ "observation": {
56
+ "status": "success",
57
+ "message": "Repair the ETL job.",
58
+ },
59
+ "reward": 0.0,
60
+ "done": False,
61
+ }
62
+ )
63
+ if url.endswith("/step"):
64
+ self._step_count += 1
65
+ return _FakeResponse(
66
+ {
67
+ "observation": {
68
+ "status": "success",
69
+ "message": "Read ok.",
70
+ },
71
+ "reward": 0.0,
72
+ "done": self._step_count >= 2,
73
+ }
74
+ )
75
+ if url.endswith("/grader/task_2_medium_syntax"):
76
+ return _FakeResponse({"score": 0.25})
77
+ raise AssertionError(f"Unexpected URL requested: {url}")
78
+
79
+
80
+ class _FakeChatCompletions:
81
+ def __init__(self, messages: list[Any]) -> None:
82
+ self._messages = iter(messages)
83
+
84
+ def create(self, **kwargs: Any) -> Any:
85
+ del kwargs
86
+ message = next(self._messages)
87
+ return SimpleNamespace(choices=[SimpleNamespace(message=message)])
88
+
89
+
90
+ class _FakeClient:
91
+ def __init__(self, messages: list[Any]) -> None:
92
+ self.chat = SimpleNamespace(completions=_FakeChatCompletions(messages))
93
+ self.base_url = "https://model.local/v1"
94
+
95
+
96
+ def _tool_message(name: str, arguments: dict[str, Any]) -> Any:
97
+ return SimpleNamespace(
98
+ tool_calls=[
99
+ SimpleNamespace(
100
+ id="call-1",
101
+ function=SimpleNamespace(
102
+ name=name,
103
+ arguments=json.dumps(arguments),
104
+ ),
105
+ )
106
+ ]
107
+ )
108
+
109
+
110
+ def test_inference_run_task_uses_env_base_url(monkeypatch, capsys) -> None:
111
+ fake_http = _FakeHTTPSession()
112
+ fake_client = _FakeClient(
113
+ [
114
+ _tool_message("read_file", {"filepath": "broken_pipeline.py"}),
115
+ SimpleNamespace(tool_calls=[]),
116
+ _tool_message("invoke_python", {"filepath": "broken_pipeline.py", "args": []}),
117
+ ]
118
+ )
119
+ monkeypatch.setattr(inference, "ENV_BASE_URL", "http://env.local")
120
+ monkeypatch.setattr(inference, "API_BASE_URL", "https://model.local/v1")
121
+ monkeypatch.setattr(inference, "MODEL_NAME", "mock-model")
122
+
123
+ score = inference.run_task(
124
+ fake_client,
125
+ fake_http,
126
+ "task_2_medium_syntax",
127
+ max_turns=4,
128
+ seed=3,
129
+ )
130
+
131
+ assert score == 0.25
132
+ assert fake_http.urls
133
+ assert all(url.startswith("http://env.local") for url in fake_http.urls)
134
+ assert all("model.local" not in url for url in fake_http.urls)
135
+
136
+ stdout = capsys.readouterr().out
137
+ assert "[START]" in stdout
138
+ assert "[STEP]" in stdout
139
+ assert "[END]" in stdout
140
+ assert "success=false" in stdout
141
+
142
+
143
+ def test_inference_emits_grader_details_to_stderr_when_enabled(monkeypatch, capsys) -> None:
144
+ class _DetailedHTTPSession(_FakeHTTPSession):
145
+ def request(
146
+ self,
147
+ method: str,
148
+ url: str,
149
+ timeout: float | None = None,
150
+ **kwargs: Any,
151
+ ) -> _FakeResponse:
152
+ if url.endswith("/grader/task_2_medium_syntax"):
153
+ return _FakeResponse(
154
+ {
155
+ "task_id": "task_2_medium_syntax",
156
+ "score": 0.25,
157
+ "details": {"reason": "Visible repair only"},
158
+ }
159
+ )
160
+ return super().request(method, url, timeout=timeout, **kwargs)
161
+
162
+ fake_http = _DetailedHTTPSession()
163
+ fake_client = _FakeClient(
164
+ [_tool_message("read_file", {"filepath": "broken_pipeline.py"})]
165
+ )
166
+ monkeypatch.setenv("PUBLIC_GRADER_DETAILS", "true")
167
+ monkeypatch.setattr(inference, "ENV_BASE_URL", "http://env.local")
168
+ monkeypatch.setattr(inference, "MODEL_NAME", "mock-model")
169
+
170
+ inference.run_task(
171
+ fake_client,
172
+ fake_http,
173
+ "task_2_medium_syntax",
174
+ max_turns=1,
175
+ seed=3,
176
+ )
177
+
178
+ stderr = capsys.readouterr().err.strip()
179
+ assert stderr
180
+ assert json.loads(stderr)["details"]["reason"] == "Visible repair only"
181
+
182
+
183
+ def test_baseline_endpoint_passes_env_base_url(monkeypatch) -> None:
184
+ captured: dict[str, Any] = {}
185
+
186
+ def fake_run(
187
+ command: list[str],
188
+ *,
189
+ cwd: str,
190
+ capture_output: bool,
191
+ text: bool,
192
+ timeout: float,
193
+ env: dict[str, str],
194
+ ) -> CompletedProcess[str]:
195
+ captured["command"] = command
196
+ captured["cwd"] = cwd
197
+ captured["capture_output"] = capture_output
198
+ captured["text"] = text
199
+ captured["timeout"] = timeout
200
+ captured["env"] = env
201
+ stdout = "\n".join(
202
+ [
203
+ "[START] task=task_1_easy_anomaly env=dataops_env model=fake-model",
204
+ "[END] success=true steps=1 score=1.000 rewards=1.00",
205
+ json.dumps(
206
+ {
207
+ "scores": {"task_1_easy_anomaly": 1.0},
208
+ "grades": {
209
+ "task_1_easy_anomaly": {
210
+ "task_id": "task_1_easy_anomaly",
211
+ "score": 1.0,
212
+ "details": {"reason": "Perfect"},
213
+ }
214
+ },
215
+ "average": 1.0,
216
+ "model": "fake-model",
217
+ "metadata": {"env_base_url": "http://127.0.0.1:7860"},
218
+ }
219
+ ),
220
+ ]
221
+ )
222
+ stderr = json.dumps({"task_id": "task_1_easy_anomaly", "score": 1.0})
223
+ return CompletedProcess(command, 0, stdout=stdout, stderr=stderr)
224
+
225
+ monkeypatch.setenv("API_KEY", "test-key")
226
+ monkeypatch.delenv("ENV_BASE_URL", raising=False)
227
+ monkeypatch.setattr(app_module.subprocess, "run", fake_run)
228
+
229
+ with TestClient(app_module.app) as client:
230
+ response = client.post(
231
+ "/baseline",
232
+ json={
233
+ "task_ids": ["task_1_easy_anomaly"],
234
+ "seed": 7,
235
+ "max_turns": 5,
236
+ },
237
+ )
238
+
239
+ assert response.status_code == 200
240
+ assert "[START] task=task_1_easy_anomaly" in response.json()["stdout"]
241
+ assert response.json()["stderr"] == json.dumps({"task_id": "task_1_easy_anomaly", "score": 1.0})
242
+ assert response.json()["scores"]["task_1_easy_anomaly"] == 1.0
243
+ assert response.json()["grades"]["task_1_easy_anomaly"]["details"]["reason"] == "Perfect"
244
+ assert captured["env"]["ENV_BASE_URL"] == "http://127.0.0.1:7860"
245
+ assert "--seed" in captured["command"]
246
+ assert "--max-turns" in captured["command"]
247
+ assert "--task" in captured["command"]
248
+
249
+
250
+ def test_inference_default_api_base_url_uses_google_for_api_key(
251
+ monkeypatch,
252
+ ) -> None:
253
+ monkeypatch.setenv("API_KEY", "test-key")
254
+ monkeypatch.delenv("HF_TOKEN", raising=False)
255
+ monkeypatch.delenv("API_BASE_URL", raising=False)
256
+
257
+ assert (
258
+ inference._resolve_api_base_url()
259
+ == inference.DEFAULT_GOOGLE_OPENAI_BASE_URL
260
+ )
261
+
262
+
263
+ def test_inference_default_api_base_url_uses_hf_router_for_hf_token(
264
+ monkeypatch,
265
+ ) -> None:
266
+ monkeypatch.setenv("HF_TOKEN", "test-token")
267
+ monkeypatch.delenv("API_KEY", raising=False)
268
+ monkeypatch.delenv("API_BASE_URL", raising=False)
269
+
270
+ assert inference._resolve_api_base_url() == inference.DEFAULT_HF_OPENAI_BASE_URL
271
+
272
+
273
+ def test_session_id_header_can_resume_http_episode() -> None:
274
+ with TestClient(app_module.app) as client:
275
+ reset = client.post("/reset?task_id=task_1_easy_anomaly", json={"seed": 5})
276
+ assert reset.status_code == 200
277
+ session_id = reset.headers["X-Session-ID"]
278
+ client.cookies.clear()
279
+
280
+ state = client.get("/state", headers={"X-Session-ID": session_id})
281
+ assert state.status_code == 200
282
+ payload = state.json()
283
+ assert payload["task_id"] == "task_1_easy_anomaly"
284
+ assert payload["seed"] == 5
285
+
286
+
287
+ def test_reset_replaces_unknown_client_supplied_session_id() -> None:
288
+ with TestClient(app_module.app) as client:
289
+ reset = client.post(
290
+ "/reset?task_id=task_1_easy_anomaly",
291
+ headers={"X-Session-ID": "attacker-chosen-session"},
292
+ json={"seed": 4},
293
+ )
294
+ assert reset.status_code == 200
295
+ issued_session_id = reset.headers["X-Session-ID"]
296
+ assert issued_session_id != "attacker-chosen-session"
297
+
298
+ client.cookies.clear()
299
+ forged_state = client.get("/state", headers={"X-Session-ID": "attacker-chosen-session"})
300
+ assert forged_state.status_code == 400
301
+
302
+ restored_state = client.get("/state", headers={"X-Session-ID": issued_session_id})
303
+ assert restored_state.status_code == 200
304
+ assert restored_state.json()["seed"] == 4
305
+
306
+
307
+ def test_websocket_reset_state_and_step_flow() -> None:
308
+ with TestClient(app_module.app) as client:
309
+ with client.websocket_connect("/ws") as websocket:
310
+ websocket.send_json(
311
+ {
312
+ "type": "reset",
313
+ "data": {"task_id": "task_1_easy_anomaly", "seed": 3},
314
+ }
315
+ )
316
+ reset_payload = websocket.receive_json()
317
+ assert reset_payload["data"]["observation"]["status"] == "success"
318
+
319
+ websocket.send_json({"type": "state"})
320
+ state_payload = websocket.receive_json()
321
+ assert state_payload["data"]["task_id"] == "task_1_easy_anomaly"
322
+
323
+ websocket.send_json(
324
+ {
325
+ "type": "step",
326
+ "data": {
327
+ "action_type": "ExecuteSQL",
328
+ "payload": {
329
+ "query": (
330
+ "SELECT id, amount FROM transactions "
331
+ "WHERE amount IS NULL ORDER BY id"
332
+ )
333
+ },
334
+ },
335
+ }
336
+ )
337
+ step_payload = websocket.receive_json()
338
+ assert step_payload["data"]["observation"]["status"] == "success"
339
+ assert step_payload["data"]["observation"]["sql_results"]
340
+
341
+ websocket.send_json({"type": "close", "data": {}})
342
+
343
+
344
+ def test_http_client_overlays_top_level_reward_and_done() -> None:
345
+ class _FakeSession:
346
+ def post(self, url: str, **kwargs: Any) -> _FakeResponse:
347
+ del kwargs
348
+ if url.endswith("/reset"):
349
+ return _FakeResponse(
350
+ {
351
+ "observation": {"status": "success", "message": "ready"},
352
+ "reward": 0.0,
353
+ "done": False,
354
+ }
355
+ )
356
+ if url.endswith("/step"):
357
+ return _FakeResponse(
358
+ {
359
+ "observation": {"status": "success", "message": "ok"},
360
+ "reward": 0.25,
361
+ "done": True,
362
+ }
363
+ )
364
+ raise AssertionError(f"Unexpected URL requested: {url}")
365
+
366
+ def get(self, url: str, **kwargs: Any) -> _FakeResponse:
367
+ del kwargs
368
+ raise AssertionError(f"Unexpected URL requested: {url}")
369
+
370
+ def close(self) -> None:
371
+ return None
372
+
373
+ client = DataOpsEnvClient(base_url="http://env.local")
374
+ client._session = _FakeSession()
375
+ try:
376
+ reset_obs = client.reset(task_id="task_1_easy_anomaly", seed=5)
377
+ assert reset_obs.reward == 0.0
378
+ assert reset_obs.done is False
379
+
380
+ step_obs = client.step(
381
+ DataOpsAction(
382
+ action_type="ExecuteSQL",
383
+ payload={"query": "SELECT 1"},
384
+ )
385
+ )
386
+ assert step_obs.reward == 0.25
387
+ assert step_obs.done is True
388
+ finally:
389
+ client.close()
390
+
391
+
392
+ def test_env_loader_uses_root_env_to_find_secondary_env_file(
393
+ tmp_path: Path, monkeypatch
394
+ ) -> None:
395
+ monkeypatch.setattr(env_loader, "_PROJECT_ROOT", tmp_path)
396
+ monkeypatch.delenv("PUBLIC_GRADER_DETAILS", raising=False)
397
+ monkeypatch.delenv("MODEL_NAME", raising=False)
398
+
399
+ (tmp_path / ".env").write_text("ENV_FILE=.env.dev\n", encoding="utf-8")
400
+ (tmp_path / ".env.dev").write_text(
401
+ "PUBLIC_GRADER_DETAILS=true\nMODEL_NAME=debug-model\n",
402
+ encoding="utf-8",
403
+ )
404
+
405
+ env_loader.load_env()
406
+
407
+ assert os.getenv("PUBLIC_GRADER_DETAILS") == "true"
408
+ assert os.getenv("MODEL_NAME") == "debug-model"
uv.lock ADDED
The diff for this file is too large to render.