Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- Dockerfile +41 -0
- README.md +248 -6
- __init__.py +14 -0
- client.py +93 -0
- data/__init__.py +0 -0
- data/init_db.py +130 -0
- env_loader.py +82 -0
- inference.py +589 -0
- models.py +113 -0
- openenv.yaml +36 -0
- pyproject.toml +42 -0
- server/__init__.py +5 -0
- server/__main__.py +6 -0
- server/app.py +530 -0
- server/dataops_env_environment.py +839 -0
- server/grading.py +557 -0
- server/requirements.txt +10 -0
- server/safe_exec.py +195 -0
- server/session_manager.py +128 -0
- server/task_specs.py +773 -0
- tests/test_grading.py +577 -0
- tests/test_inference_api.py +408 -0
- uv.lock +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12-slim AS builder
|
| 2 |
+
|
| 3 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 4 |
+
PYTHONUNBUFFERED=1
|
| 5 |
+
|
| 6 |
+
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
| 7 |
+
|
| 8 |
+
WORKDIR /app
|
| 9 |
+
|
| 10 |
+
COPY pyproject.toml uv.lock README.md ./
|
| 11 |
+
RUN uv sync --frozen --no-install-project --no-dev
|
| 12 |
+
|
| 13 |
+
COPY __init__.py client.py env_loader.py inference.py models.py openenv.yaml ./
|
| 14 |
+
COPY server ./server
|
| 15 |
+
COPY data ./data
|
| 16 |
+
RUN uv sync --frozen --no-dev
|
| 17 |
+
|
| 18 |
+
FROM python:3.12-slim
|
| 19 |
+
|
| 20 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 21 |
+
PYTHONUNBUFFERED=1 \
|
| 22 |
+
HOST=0.0.0.0 \
|
| 23 |
+
PORT=7860 \
|
| 24 |
+
PATH="/app/.venv/bin:$PATH"
|
| 25 |
+
|
| 26 |
+
WORKDIR /app
|
| 27 |
+
|
| 28 |
+
RUN useradd -m appuser
|
| 29 |
+
COPY --from=builder --chown=appuser:appuser /app /app
|
| 30 |
+
|
| 31 |
+
RUN mkdir -p /app/workspace && chown -R appuser:appuser /app
|
| 32 |
+
|
| 33 |
+
USER appuser
|
| 34 |
+
|
| 35 |
+
HEALTHCHECK --interval=30s --timeout=5s --retries=3 \
|
| 36 |
+
CMD python -c "import os, urllib.request; urllib.request.urlopen('http://127.0.0.1:' + os.getenv('PORT', '7860') + '/health')" || exit 1
|
| 37 |
+
|
| 38 |
+
EXPOSE 7860
|
| 39 |
+
|
| 40 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 41 |
+
CMD ["python", "-m", "server"]
|
README.md
CHANGED
|
@@ -1,10 +1,252 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: DataOpsEnv
|
| 3 |
+
emoji: 🧩
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
short_description: OpenEnv DataOps — SQLite, ETL repair, three graded tasks.
|
| 9 |
+
tags:
|
| 10 |
+
- openenv
|
| 11 |
+
base_path: /web
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# DataOpsEnv
|
| 15 |
+
|
| 16 |
+
[Overview](#environment-description-and-motivation) · [Tasks](#tasks-descriptions-and-expected-difficulty) · [Setup and run](#setup-and-usage) · [Baseline scores](#baseline-scores) · [Hugging Face Spaces](#hugging-face-spaces) · [HTTP API](#api-reference) · [Tests](#tests)
|
| 17 |
+
|
| 18 |
+
## Environment description and motivation
|
| 19 |
+
|
| 20 |
+
**DataOpsEnv** is an OpenEnv-compliant benchmark in which an agent performs data-engineering work: inspecting a small **SQLite** warehouse, **repairing Python ETL scripts**, and completing an **end-to-end reporting incident** (extract data, fix a formatter, send a mock email). Episodes are **seeded** (`reset` may include `seed`) so scenarios are **reproducible**; each HTTP session receives an **isolated workspace and database**.
|
| 21 |
+
|
| 22 |
+
Many agent benchmarks are game-like or shallow. Data cleaning, script debugging, and stakeholder communication reflect **real workflows**; this environment exercises multi-step tool use, constraint respect, and verifiable outcomes rather than single-shot question answering.
|
| 23 |
+
|
| 24 |
+
**Implementation:** FastAPI (`server/app.py`), environment logic (`server/dataops_env_environment.py`), terminal graders (`server/grading.py`), scenario definitions (`server/task_specs.py`, `data/init_db.py`), Pydantic types (`models.py`), OpenEnv manifest (`openenv.yaml`).
|
| 25 |
+
|
| 26 |
+
---
|
| 27 |
+
|
| 28 |
+
## Action space
|
| 29 |
+
|
| 30 |
+
Each step submits JSON: `{"action": {"action_type": "<type>", "payload": { ... }}}`. Payloads are validated per task (allowed files, SQL policy, email enabled only on the hard task).
|
| 31 |
+
|
| 32 |
+
| `action_type` | Payload fields | Role |
|
| 33 |
+
| ------------- | ---------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------- |
|
| 34 |
+
| `ExecuteSQL` | `query` (string, 1–2000 chars) | Run task-scoped SQL against the episode SQLite DB. |
|
| 35 |
+
| `ReadFile` | `filepath` (string, 1–255 chars) | Read an allowed file from the episode workspace. |
|
| 36 |
+
| `WriteFile` | `filepath`, `content` (content ≤ 1M chars) | Overwrite an allowed workspace file. |
|
| 37 |
+
| `RunScript` | `filepath` (must be `*.py` basename), `args` (optional list of strings, ≤ 20 args, each ≤ 500 chars) | Execute a Python script in the workspace with optional CLI args. |
|
| 38 |
+
| `SendEmail` | `to_email`, `subject`, `body` | Queue a mock email (used for the hard task). |
|
| 39 |
+
|
| 40 |
+
Machine-readable schema: **`GET /schema`** → `action`, or **`GET /tasks`** → `action_schema`.
|
| 41 |
+
|
| 42 |
+
---
|
| 43 |
+
|
| 44 |
+
## Observation space
|
| 45 |
+
|
| 46 |
+
Each `step` / `reset` response includes an observation object (REST also exposes wrapper fields such as `reward` / `done`). The fields below describe the **DataOps** layer; the OpenEnv base also defines `done`, `reward`, and `metadata`.
|
| 47 |
+
|
| 48 |
+
| Field | Type | Meaning |
|
| 49 |
+
| ----------------------- | ------------------------ | ----------------------------------------------------------------- |
|
| 50 |
+
| `done` | boolean | Whether the episode has ended (step limit or terminal condition). |
|
| 51 |
+
| `reward` | number \| null | Shaped **step reward** after this transition (trajectory signal). |
|
| 52 |
+
| `metadata` | object | OpenEnv extension bucket (usually empty). |
|
| 53 |
+
| `status` | `"success"` \| `"error"` | Whether the action executed successfully. |
|
| 54 |
+
| `message` | string | Short human-readable summary. |
|
| 55 |
+
| `stdout` | string \| null | Captured stdout (e.g. script or file read). |
|
| 56 |
+
| `stderr` | string \| null | Captured stderr. |
|
| 57 |
+
| `sql_results` | list of objects \| null | Row dicts for successful `SELECT`-style outcomes. |
|
| 58 |
+
| `email_delivery_status` | string \| null | Mock send confirmation when applicable. |
|
| 59 |
+
| `step_count` | integer | Steps taken in the episode. |
|
| 60 |
+
| `max_steps` | integer | Episode step budget. |
|
| 61 |
+
|
| 62 |
+
**Terminal evaluation:** The **grader score** in **[0.0, 1.0]** is returned by **`GET /grader`** (or **`GET /grader/{task_id}`**) and reflects the **final** database, files, and outbox (and, for the hard task, **provenance** constraints). Hackathon-style evaluations typically treat the **grader** as the primary benchmark metric; step rewards remain a supplementary signal. Successful actions can still return **`reward=0.0`** when they neither improve grader state nor unlock a milestone.
|
| 63 |
+
|
| 64 |
+
Machine-readable schema: **`GET /schema`** → `observation`.
|
| 65 |
+
|
| 66 |
+
---
|
| 67 |
+
|
| 68 |
+
## Tasks (descriptions and expected difficulty)
|
| 69 |
+
|
| 70 |
+
| Task ID | Expected difficulty | Description |
|
| 71 |
+
| ---------------------- | ------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| 72 |
+
| `task_1_easy_anomaly` | **Easy** | The `transactions` table contains valid rows and rows with **NULL** `amount`. The agent must **delete only** the corrupted rows and leave all valid rows **unchanged**, including legitimate seeded zero-value or negative non-null adjustments. |
|
| 73 |
+
| `task_2_medium_syntax` | **Medium** | `broken_pipeline.py` is a seeded ETL normalization job with broken filtering, priority logic, and ordering. The agent must **read**, **patch**, and **run** the script so **`process_data_stream`** produces the correct downstream-ready records on both visible and hidden seeded batches. |
|
| 74 |
+
| `task_3_hard_e2e` | **Hard** | **End-to-end incident:** query the correct **`daily_reports`** slice for the **scenario date**, persist results as **`report_data.json`**, **repair** **`format_report.py`**, **run** it on that JSON, then **send exactly one** email whose **body matches** the formatter output, with scenario-specific **recipient** and **subject**. |
|
| 75 |
+
|
| 76 |
+
Task list, difficulty labels, and allowed actions per task: **`GET /tasks`** and **`openenv.yaml`**.
|
| 77 |
+
|
| 78 |
+
---
|
| 79 |
+
|
| 80 |
+
## Setup and usage
|
| 81 |
+
|
| 82 |
+
**Prerequisites:** Python **3.12+**, **[uv](https://docs.astral.sh/uv/)**.
|
| 83 |
+
|
| 84 |
+
```bash
|
| 85 |
+
uv sync
|
| 86 |
+
cp .env.example .env.dev
|
| 87 |
+
printf 'ENV_FILE=.env.dev\n' > .env
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
Repo-root **`.env`** selects the active secondary env file. Use **`.env.dev`** for local runtime/model configuration. Hosted deployments that inject environment variables directly can skip both files.
|
| 91 |
+
|
| 92 |
+
**Run the server** from the repository root so **`HOST`**, **`PORT`**, and **`DEBUG`** from the active env file are honored:
|
| 93 |
+
|
| 94 |
+
```bash
|
| 95 |
+
uv run python -m server
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
Clients reuse the **`Set-Cookie`** session cookie or **`X-Session-ID`** header from **`POST /reset`** on **`/step`**, **`/state`**, and **`/grader`**.
|
| 99 |
+
|
| 100 |
+
**OpenEnv packaging:**
|
| 101 |
+
|
| 102 |
+
```bash
|
| 103 |
+
uv run openenv validate
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
**Docker:**
|
| 107 |
+
|
| 108 |
+
```bash
|
| 109 |
+
printf 'ENV_FILE=.env.dev\n' > .env
|
| 110 |
+
bash build_and_run_image.sh
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
The helper script reads repo-root **`.env`** only to resolve **`ENV_FILE`**, then passes that secondary file to `docker run --env-file ...`. The container does **not** receive a merged view of repo-root **`.env`** plus the secondary file. Keeping `.env*` out of the image is intentional; runtime configuration is injected from the host.
|
| 114 |
+
|
| 115 |
+
**Baseline inference (local):**
|
| 116 |
+
|
| 117 |
+
| Variable | Purpose |
|
| 118 |
+
| ---------------------- | -------------------------------------------------------------------------------------- |
|
| 119 |
+
| `ENV_BASE_URL` | Environment server URL (default `http://127.0.0.1:$PORT`, with `PORT=7860` by default) |
|
| 120 |
+
| `API_KEY` / `HF_TOKEN` | Exactly one model access credential source |
|
| 121 |
+
| `API_BASE_URL` | Optional model provider base URL override |
|
| 122 |
+
| `MODEL_NAME` | Optional Chat model ID |
|
| 123 |
+
|
| 124 |
+
```bash
|
| 125 |
+
export ENV_BASE_URL=http://127.0.0.1:7860
|
| 126 |
+
uv run python inference.py --seed 7 --max-turns 12
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
If **`API_BASE_URL`** is unset, `inference.py` defaults to Google's OpenAI-compatible Gemini endpoint for **`API_KEY`** and Hugging Face's router for **`HF_TOKEN`**.
|
| 130 |
+
|
| 131 |
+
Flags: `--task` (repeatable), `--seed`, `--max-turns`, `--json-scores` (emits one JSON object on stdout after the harness lines, including raw grader payloads when available). When `PUBLIC_GRADER_DETAILS=true` and the grader API exposes details, `inference.py` also writes the per-task grader payloads to `stderr`.
|
| 132 |
+
|
| 133 |
+
**`POST /baseline`** runs the same script inside the server process; optional JSON body: `task_ids`, `seed`, `max_turns`. If **`ADMIN_API_KEY`** is unset, the route is open. If **`ADMIN_API_KEY`** is set, callers must send **`X-Admin-Key`**. If **`ENV_BASE_URL`** is unset, the server injects **`http://127.0.0.1:$PORT`** into the child process automatically.
|
| 134 |
+
|
| 135 |
+
Agent-executed Python scripts run with a stripped environment, bounded resources, and capped captured output so task verification does not inherit model-provider secrets from the server process.
|
| 136 |
+
|
| 137 |
+
**Minimal HTTP smoke test:**
|
| 138 |
+
|
| 139 |
+
```bash
|
| 140 |
+
curl -c cookies.txt -X POST 'http://127.0.0.1:7860/reset?task_id=task_1_easy_anomaly' \
|
| 141 |
+
-H 'Content-Type: application/json' \
|
| 142 |
+
-d '{"seed": 7}'
|
| 143 |
+
|
| 144 |
+
curl -b cookies.txt -X POST 'http://127.0.0.1:7860/step' \
|
| 145 |
+
-H 'Content-Type: application/json' \
|
| 146 |
+
-d '{"action":{"action_type":"ExecuteSQL","payload":{"query":"DELETE FROM transactions WHERE amount IS NULL"}}}'
|
| 147 |
+
|
| 148 |
+
curl -b cookies.txt 'http://127.0.0.1:7860/grader'
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
By default **`/grader`** returns **`task_id`** and **`score`** only. Full grader **`details`** require **`PUBLIC_GRADER_DETAILS=true`** or a valid **`X-Admin-Key`** when **`ADMIN_API_KEY`** is set. This does **not** change the mandatory `[START]` / `[STEP]` / `[END]` lines from `inference.py`; it affects the grader API, the optional trailing JSON emitted by `--json-scores`, and the captured `stderr` payloads written by `inference.py`.
|
| 152 |
+
|
| 153 |
+
---
|
| 154 |
+
|
| 155 |
+
## Baseline scores
|
| 156 |
+
|
| 157 |
+
All figures are **terminal grader** scores in **[0.0, 1.0]**. Scores depend on provider, model revision, temperature, and `seed`.
|
| 158 |
+
|
| 159 |
+
### Null baseline (no agent actions)
|
| 160 |
+
|
| 161 |
+
| Condition | `task_1` | `task_2` | `task_3` | Avg |
|
| 162 |
+
| ---------------------------------------------------- | -------- | -------- | -------- | ---- |
|
| 163 |
+
| `reset` only (`seed=7`), then grader; **no** `/step` | 0.00 | 0.00 | 0.00 | 0.00 |
|
| 164 |
+
|
| 165 |
+
### Reference tool-calling baseline
|
| 166 |
+
|
| 167 |
+
`[END] success=true` in the harness logs means the terminal grader reached **1.0** for that task.
|
| 168 |
+
|
| 169 |
+
| Model | Seed | `task_1_easy_anomaly` | `task_2_medium_syntax` | `task_3_hard_e2e` | Average |
|
| 170 |
+
| ------------------------------- | ---- | --------------------- | ---------------------- | ----------------- | ------- |
|
| 171 |
+
| `gemini-3.1-flash-lite-preview` | 7 | 1.00 | 1.00 | 1.00 | 1.00 |
|
| 172 |
+
|
| 173 |
+
**Reproducing a baseline run:** With the API server running locally on `7860` and model credentials configured, run:
|
| 174 |
+
|
| 175 |
+
```bash
|
| 176 |
+
export MODEL_NAME=gemini-3.1-flash-lite-preview
|
| 177 |
+
export ENV_BASE_URL=http://127.0.0.1:7860
|
| 178 |
+
uv run python inference.py --seed 7 --max-turns 12 --json-scores
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
The final line of stdout is a single JSON object with **`scores`**, **`grades`**, **`average`**, **`model`**, and **`metadata`**.
|
| 182 |
+
|
| 183 |
+
---
|
| 184 |
+
|
| 185 |
+
## Hugging Face Spaces
|
| 186 |
+
|
| 187 |
+
There are two methods for running the baseline against a deployed Hugging Face Space:
|
| 188 |
+
|
| 189 |
+
1. Running **`inference.py`** externally against the public Space URL:
|
| 190 |
+
|
| 191 |
+
```bash
|
| 192 |
+
export ENV_BASE_URL=https://visheshrathi-dataops-env.hf.space
|
| 193 |
+
uv run python inference.py --seed 7 --max-turns 12 --json-scores
|
| 194 |
+
```
|
| 195 |
+
|
| 196 |
+
In this mode, the Space only needs to expose the environment API (`/reset`, `/step`, `/grader`, `/tasks`, `/schema`, `/health`, `/metadata`, `/ws`, `/mcp`). Model credentials are provided on the machine that runs **`inference.py`**, not on the Space.
|
| 197 |
+
|
| 198 |
+
2. Hitting **`/baseline`** API with a `POST` request:
|
| 199 |
+
|
| 200 |
+
```bash
|
| 201 |
+
curl -X POST 'https://visheshrathi-dataops-env.hf.space/baseline' \
|
| 202 |
+
-H 'Content-Type: application/json' \
|
| 203 |
+
-d '{"seed": 7, "max_turns": 12}'
|
| 204 |
+
```
|
| 205 |
+
|
| 206 |
+
In this mode, the Space itself executes **`inference.py`**. Configure one model credential source on the Space (**`API_KEY`** or **`HF_TOKEN`**). **`MODEL_NAME`** and **`API_BASE_URL`** are optional overrides. **`ENV_BASE_URL`** is not required for **`POST /baseline`** because the server injects **`http://127.0.0.1:$PORT`** when it launches the child `inference.py` process. If **`ADMIN_API_KEY`** is unset, **`POST /baseline`** is open; if it is set, callers must send **`X-Admin-Key`**.
|
| 207 |
+
|
| 208 |
+
---
|
| 209 |
+
|
| 210 |
+
## API reference
|
| 211 |
+
|
| 212 |
+
| Method | Path | Purpose |
|
| 213 |
+
| ------ | -------------------- | ------------------------------------------------------------- |
|
| 214 |
+
| GET | `/health` | Liveness |
|
| 215 |
+
| GET | `/metadata` | Name, description, version, task count |
|
| 216 |
+
| GET | `/schema` | JSON Schemas: action, observation, state |
|
| 217 |
+
| GET | `/tasks` | Tasks + action/observation/state schemas |
|
| 218 |
+
| POST | `/mcp` | Minimal JSON-RPC tool-list compatibility stub |
|
| 219 |
+
| POST | `/reset?task_id=...` | New episode; body may include `seed`, `episode_id` |
|
| 220 |
+
| POST | `/step` | One action; optional `timeout_s` |
|
| 221 |
+
| GET | `/state` | Episode state (`task_id`, `seed`, …) |
|
| 222 |
+
| GET | `/grader` | Terminal score for active task |
|
| 223 |
+
| GET | `/grader/{task_id}` | Same; `task_id` must match the active task |
|
| 224 |
+
| POST | `/baseline` | Subprocess baseline (see [Setup and usage](#setup-and-usage)) |
|
| 225 |
+
| WS | `/ws` | OpenEnv WebSocket session |
|
| 226 |
+
|
| 227 |
+
---
|
| 228 |
+
|
| 229 |
+
## Environment variables (server / container)
|
| 230 |
+
|
| 231 |
+
| Variable | Purpose |
|
| 232 |
+
| ------------------------ | --------------------------------------------------------------------------------------------- |
|
| 233 |
+
| `HOST` | Listen host used by `python -m server` and the container entrypoint |
|
| 234 |
+
| `PORT` | Listen port used by `python -m server` and the container entrypoint |
|
| 235 |
+
| `DEBUG` | Enables reload for local `python -m server` runs |
|
| 236 |
+
| `ENV_FILE` | Repo-relative dotenv loaded after `.env` (override) |
|
| 237 |
+
| `HTTP_SESSION_TIMEOUT_S` | HTTP session idle TTL; max wall time for **`POST /baseline`** child |
|
| 238 |
+
| `MAX_HTTP_SESSIONS` | Concurrent HTTP sessions cap |
|
| 239 |
+
| `MAX_WS_SESSIONS` | Concurrent WebSocket sessions cap |
|
| 240 |
+
| `ADMIN_API_KEY` | When set, protects **`POST /baseline`** and lets **`X-Admin-Key`** unlock full grader details |
|
| 241 |
+
| `PUBLIC_GRADER_DETAILS` | If `true`, public **`/grader`** and **`/grader/{task_id}`** responses include **`details`** |
|
| 242 |
+
| `COOKIE_SECURE` | Set `Secure` on session cookies (HTTPS) |
|
| 243 |
+
| `CORS_ALLOW_ORIGINS` | Comma-separated origins; empty disables permissive CORS (recommended default) |
|
| 244 |
+
|
| 245 |
+
---
|
| 246 |
+
|
| 247 |
+
## Tests
|
| 248 |
+
|
| 249 |
+
```bash
|
| 250 |
+
uv sync --extra dev
|
| 251 |
+
uv run pytest -q
|
| 252 |
+
```
|
__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DataOps Environment — OpenEnv-compliant enterprise data pipeline remediation environment."""
|
| 2 |
+
|
| 3 |
+
try:
|
| 4 |
+
from .client import DataOpsEnv
|
| 5 |
+
from .models import DataOpsAction, DataOpsObservation
|
| 6 |
+
except ImportError: # pragma: no cover — flat imports when loaded as top-level __init__ (e.g. pytest)
|
| 7 |
+
from client import DataOpsEnv
|
| 8 |
+
from models import DataOpsAction, DataOpsObservation
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"DataOpsAction",
|
| 12 |
+
"DataOpsObservation",
|
| 13 |
+
"DataOpsEnv",
|
| 14 |
+
]
|
client.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Typed clients for the DataOpsEnv environment."""
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import requests
|
| 6 |
+
|
| 7 |
+
from openenv.core.client_types import StepResult
|
| 8 |
+
from openenv.core.env_client import EnvClient
|
| 9 |
+
|
| 10 |
+
from models import DataOpsAction, DataOpsObservation, DataOpsState
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DataOpsEnv(EnvClient[DataOpsAction, DataOpsObservation, DataOpsState]):
|
| 14 |
+
"""Native OpenEnv WebSocket client for persistent sessions."""
|
| 15 |
+
|
| 16 |
+
def _step_payload(self, action: DataOpsAction) -> dict:
|
| 17 |
+
return action.model_dump()
|
| 18 |
+
|
| 19 |
+
def _parse_result(self, payload: dict) -> StepResult[DataOpsObservation]:
|
| 20 |
+
observation = DataOpsObservation(
|
| 21 |
+
**payload.get("observation", {}),
|
| 22 |
+
reward=payload.get("reward"),
|
| 23 |
+
done=payload.get("done", False),
|
| 24 |
+
)
|
| 25 |
+
return StepResult(
|
| 26 |
+
observation=observation,
|
| 27 |
+
reward=payload.get("reward"),
|
| 28 |
+
done=payload.get("done", False),
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
def _parse_state(self, payload: dict) -> DataOpsState:
|
| 32 |
+
return DataOpsState(**payload)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class DataOpsEnvClient:
|
| 36 |
+
"""Compatibility HTTP client for the validator-facing REST API."""
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self, base_url: str = "http://127.0.0.1:7860", timeout: float = 30.0
|
| 40 |
+
) -> None:
|
| 41 |
+
self.base_url = base_url.rstrip("/")
|
| 42 |
+
self.timeout = timeout
|
| 43 |
+
self._session = requests.Session()
|
| 44 |
+
|
| 45 |
+
@staticmethod
|
| 46 |
+
def _parse_observation(payload: dict) -> DataOpsObservation:
|
| 47 |
+
observation_payload = dict(payload.get("observation", {}))
|
| 48 |
+
if "reward" in payload:
|
| 49 |
+
observation_payload["reward"] = payload["reward"]
|
| 50 |
+
if "done" in payload:
|
| 51 |
+
observation_payload["done"] = payload["done"]
|
| 52 |
+
return DataOpsObservation(**observation_payload)
|
| 53 |
+
|
| 54 |
+
def reset(
|
| 55 |
+
self, task_id: str = "task_1_easy_anomaly", seed: Optional[int] = None,
|
| 56 |
+
) -> DataOpsObservation:
|
| 57 |
+
resp = self._session.post(
|
| 58 |
+
f"{self.base_url}/reset",
|
| 59 |
+
params={"task_id": task_id},
|
| 60 |
+
json={"seed": seed},
|
| 61 |
+
timeout=self.timeout,
|
| 62 |
+
)
|
| 63 |
+
resp.raise_for_status()
|
| 64 |
+
return self._parse_observation(resp.json())
|
| 65 |
+
|
| 66 |
+
def step(self, action: DataOpsAction) -> DataOpsObservation:
|
| 67 |
+
resp = self._session.post(
|
| 68 |
+
f"{self.base_url}/step",
|
| 69 |
+
json={"action": action.model_dump()},
|
| 70 |
+
timeout=self.timeout,
|
| 71 |
+
)
|
| 72 |
+
resp.raise_for_status()
|
| 73 |
+
return self._parse_observation(resp.json())
|
| 74 |
+
|
| 75 |
+
def state(self) -> DataOpsState:
|
| 76 |
+
resp = self._session.get(f"{self.base_url}/state", timeout=self.timeout)
|
| 77 |
+
resp.raise_for_status()
|
| 78 |
+
return DataOpsState(**resp.json())
|
| 79 |
+
|
| 80 |
+
def grade(self, task_id: Optional[str] = None) -> dict:
|
| 81 |
+
url = f"{self.base_url}/grader/{task_id}" if task_id else f"{self.base_url}/grader"
|
| 82 |
+
resp = self._session.get(url, timeout=self.timeout)
|
| 83 |
+
resp.raise_for_status()
|
| 84 |
+
return resp.json()
|
| 85 |
+
|
| 86 |
+
def close(self) -> None:
|
| 87 |
+
self._session.close()
|
| 88 |
+
|
| 89 |
+
def __enter__(self) -> "DataOpsEnvClient":
|
| 90 |
+
return self
|
| 91 |
+
|
| 92 |
+
def __exit__(self, *args: object) -> None:
|
| 93 |
+
self.close()
|
data/__init__.py
ADDED
|
File without changes
|
data/init_db.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
import sqlite3
|
| 5 |
+
|
| 6 |
+
from server.task_specs import TaskScenarioBundle, build_task_scenario
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
WORKSPACE_ROOT = os.path.join(os.path.dirname(os.path.dirname(__file__)), "workspace")
|
| 11 |
+
WORKSPACE_DIR = WORKSPACE_ROOT
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def setup_workspace(
|
| 15 |
+
workspace_dir: str | None = None, *, scenario: TaskScenarioBundle | None = None
|
| 16 |
+
) -> str:
|
| 17 |
+
"""Initialise an isolated episode workspace from the seeded scenario."""
|
| 18 |
+
target_workspace = workspace_dir or WORKSPACE_DIR
|
| 19 |
+
target_db_path = os.path.join(target_workspace, "mock_warehouse.db")
|
| 20 |
+
resolved_scenario = scenario or build_task_scenario("task_1_easy_anomaly", seed=0)
|
| 21 |
+
os.makedirs(target_workspace, exist_ok=True)
|
| 22 |
+
|
| 23 |
+
_clear_workspace(target_workspace)
|
| 24 |
+
_init_database(target_db_path, resolved_scenario)
|
| 25 |
+
_write_seeded_files(target_workspace, resolved_scenario)
|
| 26 |
+
|
| 27 |
+
logger.info(
|
| 28 |
+
"Workspace reset complete: task=%s seed=%s db=%s",
|
| 29 |
+
resolved_scenario.task_id,
|
| 30 |
+
resolved_scenario.seed,
|
| 31 |
+
target_db_path,
|
| 32 |
+
)
|
| 33 |
+
return target_db_path
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _clear_workspace(workspace_dir: str) -> None:
|
| 37 |
+
for entry in os.listdir(workspace_dir):
|
| 38 |
+
path = os.path.join(workspace_dir, entry)
|
| 39 |
+
try:
|
| 40 |
+
if os.path.isdir(path):
|
| 41 |
+
shutil.rmtree(path)
|
| 42 |
+
else:
|
| 43 |
+
os.remove(path)
|
| 44 |
+
except FileNotFoundError:
|
| 45 |
+
continue
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _init_database(db_path: str, scenario: TaskScenarioBundle) -> None:
|
| 49 |
+
conn = sqlite3.connect(db_path)
|
| 50 |
+
try:
|
| 51 |
+
c = conn.cursor()
|
| 52 |
+
c.execute(
|
| 53 |
+
"""
|
| 54 |
+
CREATE TABLE transactions (
|
| 55 |
+
id INTEGER PRIMARY KEY,
|
| 56 |
+
user_id INTEGER NOT NULL,
|
| 57 |
+
amount REAL,
|
| 58 |
+
status TEXT NOT NULL
|
| 59 |
+
)
|
| 60 |
+
"""
|
| 61 |
+
)
|
| 62 |
+
c.execute(
|
| 63 |
+
"""
|
| 64 |
+
CREATE TABLE daily_reports (
|
| 65 |
+
id INTEGER PRIMARY KEY,
|
| 66 |
+
report_date TEXT NOT NULL,
|
| 67 |
+
department TEXT NOT NULL,
|
| 68 |
+
revenue REAL NOT NULL,
|
| 69 |
+
expenses REAL NOT NULL,
|
| 70 |
+
headcount INTEGER NOT NULL
|
| 71 |
+
)
|
| 72 |
+
"""
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
if scenario.task_1:
|
| 76 |
+
c.executemany(
|
| 77 |
+
"INSERT INTO transactions VALUES (?, ?, ?, ?)",
|
| 78 |
+
[
|
| 79 |
+
(row["id"], row["user_id"], row["amount"], row["status"])
|
| 80 |
+
for row in scenario.task_1.all_rows
|
| 81 |
+
],
|
| 82 |
+
)
|
| 83 |
+
else:
|
| 84 |
+
c.executemany(
|
| 85 |
+
"INSERT INTO transactions VALUES (?, ?, ?, ?)",
|
| 86 |
+
[(1, 9000, 100.0, "success")],
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
if scenario.task_3:
|
| 90 |
+
c.executemany(
|
| 91 |
+
"INSERT INTO daily_reports VALUES (?, ?, ?, ?, ?, ?)",
|
| 92 |
+
[
|
| 93 |
+
(
|
| 94 |
+
row["id"],
|
| 95 |
+
row["report_date"],
|
| 96 |
+
row["department"],
|
| 97 |
+
row["revenue"],
|
| 98 |
+
row["expenses"],
|
| 99 |
+
row["headcount"],
|
| 100 |
+
)
|
| 101 |
+
for row in scenario.task_3.all_rows
|
| 102 |
+
],
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
conn.commit()
|
| 106 |
+
finally:
|
| 107 |
+
conn.close()
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _write_seeded_files(workspace_dir: str, scenario: TaskScenarioBundle) -> None:
|
| 111 |
+
if scenario.task_2:
|
| 112 |
+
with open(
|
| 113 |
+
os.path.join(workspace_dir, "broken_pipeline.py"),
|
| 114 |
+
"w",
|
| 115 |
+
encoding="utf-8",
|
| 116 |
+
) as f:
|
| 117 |
+
f.write(scenario.task_2.broken_script)
|
| 118 |
+
|
| 119 |
+
if scenario.task_3:
|
| 120 |
+
with open(
|
| 121 |
+
os.path.join(workspace_dir, "format_report.py"),
|
| 122 |
+
"w",
|
| 123 |
+
encoding="utf-8",
|
| 124 |
+
) as f:
|
| 125 |
+
f.write(scenario.task_3.broken_script)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
if __name__ == "__main__":
|
| 129 |
+
logging.basicConfig(level=logging.INFO)
|
| 130 |
+
setup_workspace()
|
env_loader.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Load local env files, while allowing externally injected container env vars."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
from dotenv import dotenv_values, load_dotenv
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
_PROJECT_ROOT = Path(__file__).resolve().parent
|
| 12 |
+
_RUNTIME_ENV_KEYS = (
|
| 13 |
+
"ENV_FILE",
|
| 14 |
+
"HOST",
|
| 15 |
+
"PORT",
|
| 16 |
+
"DEBUG",
|
| 17 |
+
"ENV_BASE_URL",
|
| 18 |
+
"ADMIN_API_KEY",
|
| 19 |
+
"PUBLIC_GRADER_DETAILS",
|
| 20 |
+
"COOKIE_SECURE",
|
| 21 |
+
"HTTP_SESSION_TIMEOUT_S",
|
| 22 |
+
"CORS_ALLOW_ORIGINS",
|
| 23 |
+
"MAX_HTTP_SESSIONS",
|
| 24 |
+
"MAX_WS_SESSIONS",
|
| 25 |
+
"API_KEY",
|
| 26 |
+
"HF_TOKEN",
|
| 27 |
+
"MODEL_NAME",
|
| 28 |
+
"API_BASE_URL",
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _has_external_runtime_config() -> bool:
|
| 33 |
+
return any(bool(os.getenv(key, "").strip()) for key in _RUNTIME_ENV_KEYS)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _resolve_env_file(env_file_name: str) -> Path:
|
| 37 |
+
env_path = (_PROJECT_ROOT / env_file_name).resolve()
|
| 38 |
+
try:
|
| 39 |
+
env_path.relative_to(_PROJECT_ROOT)
|
| 40 |
+
except ValueError as exc:
|
| 41 |
+
raise ValueError(
|
| 42 |
+
f"ENV_FILE '{env_file_name}' must resolve inside the project root."
|
| 43 |
+
) from exc
|
| 44 |
+
return env_path
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def load_env() -> None:
|
| 48 |
+
"""Read repo-root .env to locate the active secondary env file."""
|
| 49 |
+
dot_env = _PROJECT_ROOT / ".env"
|
| 50 |
+
if not dot_env.exists():
|
| 51 |
+
if _has_external_runtime_config():
|
| 52 |
+
logger.debug(
|
| 53 |
+
".env not found at %s — assuming runtime env vars were injected externally",
|
| 54 |
+
dot_env,
|
| 55 |
+
)
|
| 56 |
+
return
|
| 57 |
+
logger.warning(
|
| 58 |
+
".env not found at %s — local runs expect it to define ENV_FILE for the active env file",
|
| 59 |
+
dot_env,
|
| 60 |
+
)
|
| 61 |
+
return
|
| 62 |
+
|
| 63 |
+
load_dotenv(dot_env, override=False)
|
| 64 |
+
env_file_name = str(
|
| 65 |
+
(dotenv_values(dot_env).get("ENV_FILE") or os.getenv("ENV_FILE") or "")
|
| 66 |
+
).strip()
|
| 67 |
+
if not env_file_name:
|
| 68 |
+
logger.debug(".env did not specify ENV_FILE — no secondary file loaded")
|
| 69 |
+
return
|
| 70 |
+
|
| 71 |
+
try:
|
| 72 |
+
env_file = _resolve_env_file(env_file_name)
|
| 73 |
+
except ValueError as exc:
|
| 74 |
+
logger.warning("%s", exc)
|
| 75 |
+
return
|
| 76 |
+
|
| 77 |
+
if not env_file.exists():
|
| 78 |
+
logger.warning("ENV_FILE '%s' not found at %s", env_file_name, env_file)
|
| 79 |
+
return
|
| 80 |
+
|
| 81 |
+
load_dotenv(env_file, override=True)
|
| 82 |
+
logger.debug("Loaded environment variables from %s", env_file)
|
inference.py
ADDED
|
@@ -0,0 +1,589 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DataOps benchmark runner: drives the sandbox over HTTP (`/reset`, `/step`, `/grader`) with an OpenAI
|
| 3 |
+
tool-calling loop. Tool schemas are task-scoped (e.g. send_email only for the hard E2E task).
|
| 4 |
+
|
| 5 |
+
Flow per task: reset → chat completions (prefer `tool_choice="required"`) → validate tool args → POST each action →
|
| 6 |
+
append tool/observation messages until the env reports `done` or `max_turns` → GET grader score. Success is
|
| 7 |
+
derived from the score vs `SUCCESS_SCORE_THRESHOLD`.
|
| 8 |
+
|
| 9 |
+
Stdout is the harness protocol only: one `[START]`, one `[STEP]` per env step, one `[END]` (always). Use
|
| 10 |
+
`--json-scores` to append a single JSON object (scores, average, metadata) for `/baseline` ingestion.
|
| 11 |
+
|
| 12 |
+
CLI: `--task` (repeatable), `--seed`, `--max-turns`, `--json-scores`. The environment HTTP base URL comes from
|
| 13 |
+
`ENV_BASE_URL`, or if unset `http://127.0.0.1:$PORT` (default port 7860). Auth uses either `API_KEY` or
|
| 14 |
+
`HF_TOKEN`. `API_BASE_URL` is optional: when omitted, the runner defaults to Google's OpenAI-compatible Gemini
|
| 15 |
+
endpoint for `API_KEY` and Hugging Face's router for `HF_TOKEN`.
|
| 16 |
+
|
| 17 |
+
Library logging is disabled so parsers see only these lines.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import asyncio
|
| 24 |
+
import json
|
| 25 |
+
import logging
|
| 26 |
+
import os
|
| 27 |
+
import re
|
| 28 |
+
import sys
|
| 29 |
+
import zlib
|
| 30 |
+
from datetime import datetime, timezone
|
| 31 |
+
from typing import Any, Optional, Type
|
| 32 |
+
|
| 33 |
+
import requests
|
| 34 |
+
from openai import BadRequestError, OpenAI
|
| 35 |
+
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
|
| 36 |
+
from pydantic import BaseModel, ValidationError
|
| 37 |
+
|
| 38 |
+
from env_loader import load_env
|
| 39 |
+
from models import (
|
| 40 |
+
ExecuteSQLPayload,
|
| 41 |
+
ReadFilePayload,
|
| 42 |
+
RunScriptPayload,
|
| 43 |
+
SendEmailPayload,
|
| 44 |
+
WriteFilePayload,
|
| 45 |
+
)
|
| 46 |
+
from server.task_specs import TASK_IDS, TASK_METADATA
|
| 47 |
+
|
| 48 |
+
# Silence all library logging (httpx, openai, urllib3, env_loader, etc.).
|
| 49 |
+
logging.disable(logging.CRITICAL)
|
| 50 |
+
|
| 51 |
+
load_env()
|
| 52 |
+
|
| 53 |
+
DEFAULT_GOOGLE_OPENAI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
|
| 54 |
+
DEFAULT_HF_OPENAI_BASE_URL = "https://router.huggingface.co/v1"
|
| 55 |
+
|
| 56 |
+
_DEFAULT_PORT = int(os.getenv("PORT", "7860"))
|
| 57 |
+
ENV_BASE_URL = os.getenv("ENV_BASE_URL") or f"http://127.0.0.1:{_DEFAULT_PORT}"
|
| 58 |
+
MODEL_NAME = os.getenv("MODEL_NAME") or "gemini-3.1-flash-lite-preview"
|
| 59 |
+
|
| 60 |
+
BENCHMARK = "dataops_env"
|
| 61 |
+
MAX_TURNS = 12
|
| 62 |
+
SUCCESS_SCORE_THRESHOLD = 1.0
|
| 63 |
+
|
| 64 |
+
_TOOL_HELP: dict[str, str] = {
|
| 65 |
+
"execute_sql": "execute_sql — SQL over the task warehouse (field: query).",
|
| 66 |
+
"read_file": "read_file — read a workspace file (field: filepath).",
|
| 67 |
+
"write_file": "write_file — overwrite a file (fields: filepath, content).",
|
| 68 |
+
"invoke_python": "invoke_python — run a Python script (fields: filepath, optional args).",
|
| 69 |
+
"send_email": "send_email — send email (fields: to_email, subject, body).",
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
_ACTION_TO_TOOL: dict[str, str] = {
|
| 73 |
+
"ExecuteSQL": "execute_sql",
|
| 74 |
+
"ReadFile": "read_file",
|
| 75 |
+
"WriteFile": "write_file",
|
| 76 |
+
"RunScript": "invoke_python",
|
| 77 |
+
"SendEmail": "send_email",
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _allowed_tool_names_csv(task_id: str) -> str:
|
| 82 |
+
order = (
|
| 83 |
+
"execute_sql",
|
| 84 |
+
"read_file",
|
| 85 |
+
"write_file",
|
| 86 |
+
"invoke_python",
|
| 87 |
+
"send_email",
|
| 88 |
+
)
|
| 89 |
+
allowed = {_ACTION_TO_TOOL[a] for a in TASK_METADATA[task_id].allowed_actions}
|
| 90 |
+
return ", ".join(t for t in order if t in allowed)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _system_prompt_for_task(task_id: str) -> str:
|
| 94 |
+
lines = [
|
| 95 |
+
_TOOL_HELP[t]
|
| 96 |
+
for t in (
|
| 97 |
+
"execute_sql",
|
| 98 |
+
"read_file",
|
| 99 |
+
"write_file",
|
| 100 |
+
"invoke_python",
|
| 101 |
+
"send_email",
|
| 102 |
+
)
|
| 103 |
+
if t in {_ACTION_TO_TOOL[a] for a in TASK_METADATA[task_id].allowed_actions}
|
| 104 |
+
]
|
| 105 |
+
tools_block = "\n".join(f" - {line}" for line in lines)
|
| 106 |
+
return f"""\
|
| 107 |
+
You are an expert DataOps agent in a task-scoped benchmark. Only the tools listed below exist for this task — do not assume other actions are available.
|
| 108 |
+
|
| 109 |
+
Available tools:
|
| 110 |
+
{tools_block}
|
| 111 |
+
|
| 112 |
+
Rules:
|
| 113 |
+
- Always read files before modifying them when read_file is available.
|
| 114 |
+
- After writing a fix, run the script to verify it works when invoke_python is available.
|
| 115 |
+
- Be precise. Do not drop tables. Do not guess — inspect first.
|
| 116 |
+
- For tasks that include send_email, match subject and body to the task description exactly.
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
TASK_PROMPTS = {
|
| 121 |
+
"task_1_easy_anomaly": (
|
| 122 |
+
"Solve the seeded cleanup task carefully. Inspect before mutating. Only NULL-amount rows are corrupted; preserve every non-null row exactly, including legitimate zero or negative adjustments."
|
| 123 |
+
),
|
| 124 |
+
"task_2_medium_syntax": (
|
| 125 |
+
"Solve the seeded script-repair task. Read the file, make the minimal correct fix, and verify with execution."
|
| 126 |
+
),
|
| 127 |
+
"task_3_hard_e2e": (
|
| 128 |
+
"Solve the seeded incident task end to end. Use SQL for the exact slice, write the exact JSON file, "
|
| 129 |
+
"repair the formatter, execute it, and email the exact generated report."
|
| 130 |
+
),
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def log_start(task: str, env: str, model: str) -> None:
|
| 135 |
+
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def log_step(
|
| 139 |
+
step: int, action: str, reward: float, done: bool, error: Optional[str]
|
| 140 |
+
) -> None:
|
| 141 |
+
error_val = error if error else "null"
|
| 142 |
+
done_val = str(done).lower()
|
| 143 |
+
print(
|
| 144 |
+
f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
|
| 145 |
+
flush=True,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
|
| 150 |
+
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 151 |
+
print(
|
| 152 |
+
f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
|
| 153 |
+
flush=True,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _public_grader_details_enabled() -> bool:
|
| 158 |
+
return os.getenv("PUBLIC_GRADER_DETAILS", "").strip().lower() in {"1", "true", "yes"}
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _emit_grader_details_to_stderr(grade: dict[str, Any]) -> None:
|
| 162 |
+
if not _public_grader_details_enabled():
|
| 163 |
+
return
|
| 164 |
+
if "details" not in grade:
|
| 165 |
+
return
|
| 166 |
+
print(json.dumps(grade, ensure_ascii=False), file=sys.stderr, flush=True)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def _request_json(
|
| 170 |
+
http: requests.Session,
|
| 171 |
+
method: str,
|
| 172 |
+
path: str,
|
| 173 |
+
*,
|
| 174 |
+
timeout: float,
|
| 175 |
+
**kwargs: Any,
|
| 176 |
+
) -> dict[str, Any]:
|
| 177 |
+
response = http.request(method, f"{ENV_BASE_URL}{path}", timeout=timeout, **kwargs)
|
| 178 |
+
response.raise_for_status()
|
| 179 |
+
return response.json()
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _build_tools(task_id: str) -> list[ChatCompletionToolParam]:
|
| 183 |
+
defs: dict[str, tuple[str, Type[BaseModel]]] = {
|
| 184 |
+
"execute_sql": (
|
| 185 |
+
"Run a task-scoped SQL query against the SQLite warehouse DB.",
|
| 186 |
+
ExecuteSQLPayload,
|
| 187 |
+
),
|
| 188 |
+
"read_file": ("Read a file in the workspace.", ReadFilePayload),
|
| 189 |
+
"write_file": ("Overwrite a file with new content.", WriteFilePayload),
|
| 190 |
+
"invoke_python": (
|
| 191 |
+
"Execute a Python script in the workspace (optional args).",
|
| 192 |
+
RunScriptPayload,
|
| 193 |
+
),
|
| 194 |
+
"send_email": ("Send a formatted email notification.", SendEmailPayload),
|
| 195 |
+
}
|
| 196 |
+
allowed_names = {_ACTION_TO_TOOL[a] for a in TASK_METADATA[task_id].allowed_actions}
|
| 197 |
+
return [
|
| 198 |
+
{
|
| 199 |
+
"type": "function",
|
| 200 |
+
"function": {
|
| 201 |
+
"name": name,
|
| 202 |
+
"description": defs[name][0],
|
| 203 |
+
"parameters": defs[name][1].model_json_schema(),
|
| 204 |
+
},
|
| 205 |
+
}
|
| 206 |
+
for name in (
|
| 207 |
+
"execute_sql",
|
| 208 |
+
"read_file",
|
| 209 |
+
"write_file",
|
| 210 |
+
"invoke_python",
|
| 211 |
+
"send_email",
|
| 212 |
+
)
|
| 213 |
+
if name in allowed_names
|
| 214 |
+
]
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def _tool_call_to_action(name: str, arguments: str) -> dict[str, Any]:
|
| 218 |
+
if name == "run_script":
|
| 219 |
+
name = "invoke_python"
|
| 220 |
+
mapping: dict[str, tuple[str, Type[BaseModel]]] = {
|
| 221 |
+
"execute_sql": ("ExecuteSQL", ExecuteSQLPayload),
|
| 222 |
+
"read_file": ("ReadFile", ReadFilePayload),
|
| 223 |
+
"write_file": ("WriteFile", WriteFilePayload),
|
| 224 |
+
"invoke_python": ("RunScript", RunScriptPayload),
|
| 225 |
+
"send_email": ("SendEmail", SendEmailPayload),
|
| 226 |
+
}
|
| 227 |
+
if name not in mapping:
|
| 228 |
+
raise ValueError(f"Unknown tool: {name}")
|
| 229 |
+
action_type, model = mapping[name]
|
| 230 |
+
data = json.loads(arguments) if (arguments or "").strip() else {}
|
| 231 |
+
payload = model.model_validate(data).model_dump()
|
| 232 |
+
return {"action_type": action_type, "payload": payload}
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
_MALFORMED_TOOL = re.compile(
|
| 236 |
+
r"^([a-zA-Z_][a-zA-Z0-9_]*)[\s,=\(]+(\{.*\})\)?\s*$", re.DOTALL
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def _normalize_tool_name_and_args(name: str, arguments: str) -> tuple[str, str]:
|
| 241 |
+
name = (name or "").strip()
|
| 242 |
+
arguments = (arguments or "").strip()
|
| 243 |
+
m = _MALFORMED_TOOL.match(name)
|
| 244 |
+
if m:
|
| 245 |
+
base, embedded = m.group(1).strip(), m.group(2).strip()
|
| 246 |
+
if not arguments:
|
| 247 |
+
return base, embedded
|
| 248 |
+
return name, arguments
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def _action_from_tool_call(tc: Any) -> dict[str, Any]:
|
| 252 |
+
name, arguments = _normalize_tool_name_and_args(
|
| 253 |
+
tc.function.name or "", tc.function.arguments or ""
|
| 254 |
+
)
|
| 255 |
+
return _tool_call_to_action(name, arguments)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def _action_str(action_payload: dict[str, Any]) -> str:
|
| 259 |
+
at = action_payload.get("action_type", "")
|
| 260 |
+
pl = action_payload.get("payload") or {}
|
| 261 |
+
raw = f"{at}({json.dumps(pl, ensure_ascii=False)})"
|
| 262 |
+
if len(raw) > 1200:
|
| 263 |
+
return raw[:600] + "..." + raw[-550:]
|
| 264 |
+
return raw
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def _obs_error(obs: dict[str, Any]) -> Optional[str]:
|
| 268 |
+
if obs.get("status") != "error":
|
| 269 |
+
return None
|
| 270 |
+
msg = obs.get("message")
|
| 271 |
+
if isinstance(msg, str) and msg.strip():
|
| 272 |
+
return msg.replace("\n", " ").strip()
|
| 273 |
+
return None
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def _resolve_api_base_url() -> str:
|
| 277 |
+
explicit = os.getenv("API_BASE_URL", "").strip()
|
| 278 |
+
if explicit:
|
| 279 |
+
return explicit
|
| 280 |
+
if os.getenv("HF_TOKEN", "").strip():
|
| 281 |
+
return DEFAULT_HF_OPENAI_BASE_URL
|
| 282 |
+
return DEFAULT_GOOGLE_OPENAI_BASE_URL
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
API_BASE_URL = _resolve_api_base_url()
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def _openai_client() -> OpenAI:
|
| 289 |
+
key = (os.getenv("HF_TOKEN") or os.getenv("API_KEY") or "").strip()
|
| 290 |
+
if not key:
|
| 291 |
+
print(
|
| 292 |
+
"[inference] Missing API_KEY or HF_TOKEN for model access.",
|
| 293 |
+
file=sys.stderr,
|
| 294 |
+
flush=True,
|
| 295 |
+
)
|
| 296 |
+
sys.exit(1)
|
| 297 |
+
return OpenAI(api_key=key, base_url=API_BASE_URL)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def _llm_seed(env_seed: int | None, task_id: str) -> int | None:
|
| 301 |
+
if env_seed is None:
|
| 302 |
+
return None
|
| 303 |
+
mixed = (int(env_seed) * 1_000_003) ^ (zlib.crc32(task_id.encode()) & 0xFFFFFFFF)
|
| 304 |
+
return mixed & 0x7FFFFFFF
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def _create_chat_completion(
|
| 308 |
+
client: OpenAI,
|
| 309 |
+
messages: list[ChatCompletionMessageParam],
|
| 310 |
+
tools: list[ChatCompletionToolParam],
|
| 311 |
+
*,
|
| 312 |
+
task_id: str,
|
| 313 |
+
env_seed: int | None,
|
| 314 |
+
) -> Any:
|
| 315 |
+
"""Prefer tool_choice=required so the model cannot end a turn without a tool call."""
|
| 316 |
+
kwargs: dict[str, Any] = {
|
| 317 |
+
"model": MODEL_NAME,
|
| 318 |
+
"messages": messages,
|
| 319 |
+
"tools": tools,
|
| 320 |
+
"parallel_tool_calls": False,
|
| 321 |
+
"temperature": 0,
|
| 322 |
+
"top_p": 1.0,
|
| 323 |
+
}
|
| 324 |
+
llm_seed = _llm_seed(env_seed, task_id)
|
| 325 |
+
if llm_seed is not None:
|
| 326 |
+
kwargs["seed"] = llm_seed
|
| 327 |
+
|
| 328 |
+
def _call(tool_choice: str) -> Any:
|
| 329 |
+
return client.chat.completions.create(**kwargs, tool_choice=tool_choice)
|
| 330 |
+
|
| 331 |
+
try:
|
| 332 |
+
return _call("required")
|
| 333 |
+
except BadRequestError as e:
|
| 334 |
+
err = str(e).lower()
|
| 335 |
+
if "seed" in err and llm_seed is not None:
|
| 336 |
+
kwargs.pop("seed", None)
|
| 337 |
+
try:
|
| 338 |
+
return _call("required")
|
| 339 |
+
except BadRequestError as e2:
|
| 340 |
+
err = str(e2).lower()
|
| 341 |
+
if not any(x in err for x in ("tool_choice", "required", "unsupported")):
|
| 342 |
+
raise
|
| 343 |
+
return _call("auto")
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def run_task(
|
| 347 |
+
client: OpenAI,
|
| 348 |
+
http: requests.Session,
|
| 349 |
+
task_id: str,
|
| 350 |
+
*,
|
| 351 |
+
max_turns: int,
|
| 352 |
+
seed: int | None,
|
| 353 |
+
) -> float:
|
| 354 |
+
rewards: list[float] = []
|
| 355 |
+
steps_taken = 0
|
| 356 |
+
score = 0.0
|
| 357 |
+
success = False
|
| 358 |
+
|
| 359 |
+
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
|
| 360 |
+
|
| 361 |
+
try:
|
| 362 |
+
tools = _build_tools(task_id)
|
| 363 |
+
names_csv = _allowed_tool_names_csv(task_id)
|
| 364 |
+
reset_resp = _request_json(
|
| 365 |
+
http,
|
| 366 |
+
"POST",
|
| 367 |
+
"/reset",
|
| 368 |
+
timeout=10,
|
| 369 |
+
params={"task_id": task_id},
|
| 370 |
+
json={} if seed is None else {"seed": seed},
|
| 371 |
+
)
|
| 372 |
+
reset_obs = reset_resp.get("observation", reset_resp)
|
| 373 |
+
|
| 374 |
+
messages: list[ChatCompletionMessageParam] = [
|
| 375 |
+
{"role": "system", "content": _system_prompt_for_task(task_id)},
|
| 376 |
+
{
|
| 377 |
+
"role": "user",
|
| 378 |
+
"content": TASK_PROMPTS[task_id]
|
| 379 |
+
+ f"\n\nEnvironment says: {reset_obs['message']}",
|
| 380 |
+
},
|
| 381 |
+
]
|
| 382 |
+
|
| 383 |
+
done = False
|
| 384 |
+
step_num = 0
|
| 385 |
+
no_tool_streak = 0
|
| 386 |
+
for turn in range(1, max_turns + 1):
|
| 387 |
+
try:
|
| 388 |
+
response = _create_chat_completion(
|
| 389 |
+
client,
|
| 390 |
+
messages,
|
| 391 |
+
tools,
|
| 392 |
+
task_id=task_id,
|
| 393 |
+
env_seed=seed,
|
| 394 |
+
)
|
| 395 |
+
except BadRequestError as e:
|
| 396 |
+
err_str = str(e).lower()
|
| 397 |
+
if "tool" not in err_str and "function" not in err_str:
|
| 398 |
+
raise
|
| 399 |
+
if messages and messages[-1].get("role") == "assistant": # type: ignore[union-attr]
|
| 400 |
+
messages.pop()
|
| 401 |
+
messages.append(
|
| 402 |
+
{
|
| 403 |
+
"role": "user",
|
| 404 |
+
"content": (
|
| 405 |
+
"IMPORTANT: Call tools using ONLY these exact names: "
|
| 406 |
+
f"{names_csv}. "
|
| 407 |
+
"Put ALL parameters inside the tool's JSON arguments field. "
|
| 408 |
+
"Do NOT embed parameters in the tool name itself."
|
| 409 |
+
),
|
| 410 |
+
}
|
| 411 |
+
)
|
| 412 |
+
try:
|
| 413 |
+
response = _create_chat_completion(
|
| 414 |
+
client,
|
| 415 |
+
messages,
|
| 416 |
+
tools,
|
| 417 |
+
task_id=task_id,
|
| 418 |
+
env_seed=seed,
|
| 419 |
+
)
|
| 420 |
+
except BadRequestError:
|
| 421 |
+
break
|
| 422 |
+
msg = response.choices[0].message
|
| 423 |
+
|
| 424 |
+
if not msg.tool_calls:
|
| 425 |
+
no_tool_streak += 1
|
| 426 |
+
if no_tool_streak > 3:
|
| 427 |
+
break
|
| 428 |
+
messages.append(msg) # type: ignore[arg-type]
|
| 429 |
+
messages.append(
|
| 430 |
+
{
|
| 431 |
+
"role": "user",
|
| 432 |
+
"content": (
|
| 433 |
+
f"You must respond with exactly one tool call ({names_csv}). "
|
| 434 |
+
"Do not reply with plain text only."
|
| 435 |
+
),
|
| 436 |
+
}
|
| 437 |
+
)
|
| 438 |
+
continue
|
| 439 |
+
no_tool_streak = 0
|
| 440 |
+
|
| 441 |
+
messages.append(msg) # type: ignore[arg-type]
|
| 442 |
+
|
| 443 |
+
for tc in msg.tool_calls:
|
| 444 |
+
try:
|
| 445 |
+
action_payload = _action_from_tool_call(tc)
|
| 446 |
+
except (json.JSONDecodeError, ValidationError, ValueError) as e:
|
| 447 |
+
messages.append(
|
| 448 |
+
{
|
| 449 |
+
"role": "tool",
|
| 450 |
+
"tool_call_id": tc.id,
|
| 451 |
+
"content": f"Invalid tool arguments: {e}",
|
| 452 |
+
}
|
| 453 |
+
)
|
| 454 |
+
continue
|
| 455 |
+
|
| 456 |
+
step_num += 1
|
| 457 |
+
step_resp = _request_json(
|
| 458 |
+
http,
|
| 459 |
+
"POST",
|
| 460 |
+
"/step",
|
| 461 |
+
timeout=30,
|
| 462 |
+
json={"action": action_payload},
|
| 463 |
+
)
|
| 464 |
+
obs = step_resp.get("observation", step_resp)
|
| 465 |
+
reward_raw = step_resp.get("reward")
|
| 466 |
+
reward = 0.0 if reward_raw is None else float(reward_raw)
|
| 467 |
+
done = step_resp.get("done", False)
|
| 468 |
+
|
| 469 |
+
rewards.append(reward)
|
| 470 |
+
steps_taken = step_num
|
| 471 |
+
err = _obs_error(obs if isinstance(obs, dict) else {})
|
| 472 |
+
log_step(
|
| 473 |
+
step=step_num,
|
| 474 |
+
action=_action_str(action_payload),
|
| 475 |
+
reward=reward,
|
| 476 |
+
done=done,
|
| 477 |
+
error=err,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
messages.append(
|
| 481 |
+
{"role": "tool", "tool_call_id": tc.id, "content": json.dumps(obs)}
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
if done:
|
| 485 |
+
break
|
| 486 |
+
|
| 487 |
+
if done:
|
| 488 |
+
break
|
| 489 |
+
|
| 490 |
+
grade = _request_json(http, "GET", f"/grader/{task_id}", timeout=10)
|
| 491 |
+
_emit_grader_details_to_stderr(grade)
|
| 492 |
+
score = float(grade["score"])
|
| 493 |
+
score = min(max(score, 0.0), 1.0)
|
| 494 |
+
success = score >= SUCCESS_SCORE_THRESHOLD
|
| 495 |
+
except Exception as exc:
|
| 496 |
+
print(
|
| 497 |
+
f"[inference] task={task_id} failed: {exc!r}", file=sys.stderr, flush=True
|
| 498 |
+
)
|
| 499 |
+
finally:
|
| 500 |
+
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
|
| 501 |
+
|
| 502 |
+
return score
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def _parse_args() -> argparse.Namespace:
|
| 506 |
+
p = argparse.ArgumentParser(
|
| 507 |
+
description="DataOpsEnv inference (OpenAI client; protocol lines on stdout)."
|
| 508 |
+
)
|
| 509 |
+
p.add_argument(
|
| 510 |
+
"--task",
|
| 511 |
+
action="append",
|
| 512 |
+
choices=TASK_IDS,
|
| 513 |
+
dest="tasks",
|
| 514 |
+
help="Run only the selected task(s). Defaults to all tasks.",
|
| 515 |
+
)
|
| 516 |
+
p.add_argument(
|
| 517 |
+
"--seed",
|
| 518 |
+
type=int,
|
| 519 |
+
default=None,
|
| 520 |
+
help="Environment seed for /reset; also used for LLM seed when the API supports it.",
|
| 521 |
+
)
|
| 522 |
+
p.add_argument(
|
| 523 |
+
"--max-turns",
|
| 524 |
+
type=int,
|
| 525 |
+
default=MAX_TURNS,
|
| 526 |
+
help=f"Maximum tool-using turns per task (default: {MAX_TURNS}).",
|
| 527 |
+
)
|
| 528 |
+
p.add_argument(
|
| 529 |
+
"--json-scores",
|
| 530 |
+
action="store_true",
|
| 531 |
+
help="Print a final JSON object with scores to stdout (for POST /baseline).",
|
| 532 |
+
)
|
| 533 |
+
return p.parse_args()
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def _run_inference_sync(args: argparse.Namespace) -> None:
|
| 537 |
+
client = _openai_client()
|
| 538 |
+
scores: dict[str, float] = {}
|
| 539 |
+
grades: dict[str, dict[str, Any]] = {}
|
| 540 |
+
task_ids = args.tasks or list(TASK_PROMPTS)
|
| 541 |
+
with requests.Session() as http:
|
| 542 |
+
for task_id in task_ids:
|
| 543 |
+
scores[task_id] = run_task(
|
| 544 |
+
client,
|
| 545 |
+
http,
|
| 546 |
+
task_id,
|
| 547 |
+
max_turns=max(1, int(args.max_turns)),
|
| 548 |
+
seed=args.seed,
|
| 549 |
+
)
|
| 550 |
+
if args.json_scores:
|
| 551 |
+
try:
|
| 552 |
+
grades[task_id] = _request_json(
|
| 553 |
+
http,
|
| 554 |
+
"GET",
|
| 555 |
+
f"/grader/{task_id}",
|
| 556 |
+
timeout=10,
|
| 557 |
+
)
|
| 558 |
+
except Exception:
|
| 559 |
+
grades[task_id] = {
|
| 560 |
+
"task_id": task_id,
|
| 561 |
+
"score": scores[task_id],
|
| 562 |
+
}
|
| 563 |
+
|
| 564 |
+
if args.json_scores:
|
| 565 |
+
avg = sum(scores.values()) / len(scores)
|
| 566 |
+
payload = {
|
| 567 |
+
"scores": scores,
|
| 568 |
+
"grades": grades,
|
| 569 |
+
"average": round(avg, 4),
|
| 570 |
+
"model": MODEL_NAME,
|
| 571 |
+
"metadata": {
|
| 572 |
+
"env_base_url": ENV_BASE_URL,
|
| 573 |
+
"seed": args.seed,
|
| 574 |
+
"max_turns": max(1, int(args.max_turns)),
|
| 575 |
+
"tasks": task_ids,
|
| 576 |
+
"generated_at_utc": datetime.now(timezone.utc).isoformat(),
|
| 577 |
+
"model_base_url": str(getattr(client, "base_url", "")),
|
| 578 |
+
},
|
| 579 |
+
}
|
| 580 |
+
print(json.dumps(payload), flush=True)
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
async def main() -> None:
|
| 584 |
+
args = _parse_args()
|
| 585 |
+
await asyncio.to_thread(_run_inference_sync, args)
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
if __name__ == "__main__":
|
| 589 |
+
asyncio.run(main())
|
models.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 3 |
+
|
| 4 |
+
from openenv.core.env_server import (
|
| 5 |
+
Action as BaseAction,
|
| 6 |
+
)
|
| 7 |
+
from openenv.core.env_server import (
|
| 8 |
+
Observation as BaseObservation,
|
| 9 |
+
)
|
| 10 |
+
from openenv.core.env_server import (
|
| 11 |
+
State as BaseState,
|
| 12 |
+
)
|
| 13 |
+
from pydantic import BaseModel, Field, field_validator
|
| 14 |
+
|
| 15 |
+
# ── Action Payload Models (Pydantic-validated) ─────────────────────
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ExecuteSQLPayload(BaseModel):
|
| 19 |
+
query: str = Field(..., min_length=1, max_length=2000)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ReadFilePayload(BaseModel):
|
| 23 |
+
filepath: str = Field(..., min_length=1, max_length=255)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class WriteFilePayload(BaseModel):
|
| 27 |
+
filepath: str = Field(..., min_length=1, max_length=255)
|
| 28 |
+
content: str = Field(..., max_length=1_000_000)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class RunScriptPayload(BaseModel):
|
| 32 |
+
filepath: str = Field(..., min_length=1, max_length=255)
|
| 33 |
+
args: List[str] = Field(default_factory=list, max_length=20)
|
| 34 |
+
|
| 35 |
+
@field_validator("filepath")
|
| 36 |
+
@classmethod
|
| 37 |
+
def must_be_safe_script_name(cls, v: str) -> str:
|
| 38 |
+
basename = v.rsplit("/", 1)[-1]
|
| 39 |
+
if not re.match(r"^[a-zA-Z0-9_\-]+\.py$", basename):
|
| 40 |
+
raise ValueError("Script name must be alphanumeric with .py extension.")
|
| 41 |
+
return v
|
| 42 |
+
|
| 43 |
+
@field_validator("args")
|
| 44 |
+
@classmethod
|
| 45 |
+
def args_must_be_safe(cls, v: list[str]) -> list[str]:
|
| 46 |
+
for arg in v:
|
| 47 |
+
if not isinstance(arg, str) or len(arg) > 500:
|
| 48 |
+
raise ValueError("Each arg must be a string under 500 chars.")
|
| 49 |
+
return v
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class SendEmailPayload(BaseModel):
|
| 53 |
+
to_email: str = Field(..., max_length=320)
|
| 54 |
+
subject: str = Field(..., min_length=1, max_length=500)
|
| 55 |
+
body: str = Field(..., min_length=1, max_length=100_000)
|
| 56 |
+
|
| 57 |
+
@field_validator("to_email")
|
| 58 |
+
@classmethod
|
| 59 |
+
def must_look_like_email(cls, v: str) -> str:
|
| 60 |
+
if not re.match(r"^[^@\s]+@[^@\s]+\.[^@\s]+$", v):
|
| 61 |
+
raise ValueError("Invalid email format.")
|
| 62 |
+
return v
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
ACTION_TYPE = Literal["ExecuteSQL", "ReadFile", "WriteFile", "RunScript", "SendEmail"]
|
| 66 |
+
|
| 67 |
+
PAYLOAD_MODELS: dict[str, type[BaseModel]] = {
|
| 68 |
+
"ExecuteSQL": ExecuteSQLPayload,
|
| 69 |
+
"ReadFile": ReadFilePayload,
|
| 70 |
+
"WriteFile": WriteFilePayload,
|
| 71 |
+
"RunScript": RunScriptPayload,
|
| 72 |
+
"SendEmail": SendEmailPayload,
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ── Action Model (extends OpenEnv Action) ──────────────────────────
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class DataOpsAction(BaseAction):
|
| 80 |
+
action_type: ACTION_TYPE = Field(
|
| 81 |
+
..., description="One of: ExecuteSQL, ReadFile, WriteFile, RunScript, SendEmail"
|
| 82 |
+
)
|
| 83 |
+
payload: Dict[str, Any] = Field(
|
| 84 |
+
..., description="Parameters for the chosen action type."
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# ── Observation Model (extends OpenEnv Observation) ────────────────
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class DataOpsObservation(BaseObservation):
|
| 92 |
+
status: Literal["success", "error"] = "error"
|
| 93 |
+
message: str = ""
|
| 94 |
+
stdout: Optional[str] = None
|
| 95 |
+
stderr: Optional[str] = None
|
| 96 |
+
sql_results: Optional[List[Dict[str, Any]]] = None
|
| 97 |
+
email_delivery_status: Optional[str] = None
|
| 98 |
+
step_count: int = 0
|
| 99 |
+
max_steps: int = 0
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ── State Model (extends OpenEnv State) ────────────────────────────
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class DataOpsState(BaseState):
|
| 106 |
+
task_id: str = ""
|
| 107 |
+
task_description: str = ""
|
| 108 |
+
seed: int = 0
|
| 109 |
+
max_steps: int = 15
|
| 110 |
+
done: bool = False
|
| 111 |
+
cumulative_reward: float = 0.0
|
| 112 |
+
actions_taken: List[str] = Field(default_factory=list)
|
| 113 |
+
emails_sent: int = 0
|
openenv.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: dataops_env
|
| 3 |
+
version: 1.0.0
|
| 4 |
+
description: Seeded enterprise DataOps benchmark with isolated sessions, deterministic graders, and three escalating remediation tasks.
|
| 5 |
+
type: space
|
| 6 |
+
runtime: fastapi
|
| 7 |
+
app: server.app:app
|
| 8 |
+
port: 7860
|
| 9 |
+
tasks:
|
| 10 |
+
- id: task_1_easy_anomaly
|
| 11 |
+
name: Delete Corrupted Transaction Rows
|
| 12 |
+
difficulty: easy
|
| 13 |
+
description: Inspect a seeded transaction table and remove only the rows whose amount is NULL, preserving legitimate non-null edge values.
|
| 14 |
+
benchmark_focus: Careful data cleanup without collateral damage.
|
| 15 |
+
allowed_actions:
|
| 16 |
+
- ExecuteSQL
|
| 17 |
+
- id: task_2_medium_syntax
|
| 18 |
+
name: Repair Seeded Pipeline Script
|
| 19 |
+
difficulty: medium
|
| 20 |
+
description: Repair a seeded ETL normalization script and verify it against visible and hidden seeded batches.
|
| 21 |
+
benchmark_focus: Code reading, precise repair, and generalization beyond the demo batch.
|
| 22 |
+
allowed_actions:
|
| 23 |
+
- ReadFile
|
| 24 |
+
- WriteFile
|
| 25 |
+
- RunScript
|
| 26 |
+
- id: task_3_hard_e2e
|
| 27 |
+
name: Resolve Revenue Reporting Incident
|
| 28 |
+
difficulty: hard
|
| 29 |
+
description: Extract a seeded reporting slice, repair the formatter, and send the exact generated report.
|
| 30 |
+
benchmark_focus: End-to-end data extraction, file repair, and communication with provenance.
|
| 31 |
+
allowed_actions:
|
| 32 |
+
- ExecuteSQL
|
| 33 |
+
- ReadFile
|
| 34 |
+
- WriteFile
|
| 35 |
+
- RunScript
|
| 36 |
+
- SendEmail
|
pyproject.toml
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=45", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "openenv-dataops_env"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "Enterprise data pipeline remediation environment for training AI agents (OpenEnv-compliant)."
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.12"
|
| 11 |
+
dependencies = [
|
| 12 |
+
"openenv-core[core]>=0.2.2",
|
| 13 |
+
"fastapi>=0.115.0",
|
| 14 |
+
"starlette>=0.46.0,<0.52.0",
|
| 15 |
+
"uvicorn[standard]>=0.34.0",
|
| 16 |
+
"pydantic>=2.10.0",
|
| 17 |
+
"pyyaml>=6.0.2",
|
| 18 |
+
"openai>=1.60.0",
|
| 19 |
+
"requests>=2.32.0",
|
| 20 |
+
"wsproto>=1.3.2",
|
| 21 |
+
"python-dotenv>=1.0.0",
|
| 22 |
+
"huggingface-hub>=1.8.0",
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
[project.optional-dependencies]
|
| 26 |
+
dev = [
|
| 27 |
+
"pytest>=8.0.0",
|
| 28 |
+
"pytest-cov>=4.0.0",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[project.scripts]
|
| 32 |
+
server = "dataops_env.server.app:main"
|
| 33 |
+
|
| 34 |
+
[tool.setuptools]
|
| 35 |
+
include-package-data = true
|
| 36 |
+
packages = ["dataops_env", "dataops_env.server"]
|
| 37 |
+
package-dir = { "dataops_env" = ".", "dataops_env.server" = "server" }
|
| 38 |
+
|
| 39 |
+
[tool.pytest.ini_options]
|
| 40 |
+
pythonpath = ["."]
|
| 41 |
+
testpaths = ["tests"]
|
| 42 |
+
addopts = "--import-mode=importlib"
|
server/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DataOps environment server components."""
|
| 2 |
+
|
| 3 |
+
from .dataops_env_environment import DataOpsEnvironment
|
| 4 |
+
|
| 5 |
+
__all__ = ["DataOpsEnvironment"]
|
server/__main__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Run the API server from the repo root: ``uv run python -m server``."""
|
| 2 |
+
|
| 3 |
+
from server.app import main
|
| 4 |
+
|
| 5 |
+
if __name__ == "__main__":
|
| 6 |
+
main()
|
server/app.py
ADDED
|
@@ -0,0 +1,530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import subprocess
|
| 6 |
+
import sys
|
| 7 |
+
from collections.abc import AsyncIterator
|
| 8 |
+
from contextlib import asynccontextmanager
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import yaml
|
| 12 |
+
import uvicorn
|
| 13 |
+
from fastapi import (
|
| 14 |
+
Body,
|
| 15 |
+
FastAPI,
|
| 16 |
+
HTTPException,
|
| 17 |
+
Query,
|
| 18 |
+
Request,
|
| 19 |
+
Response,
|
| 20 |
+
WebSocket,
|
| 21 |
+
WebSocketDisconnect,
|
| 22 |
+
)
|
| 23 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 24 |
+
from openenv.core.env_server.http_server import serialize_observation
|
| 25 |
+
from openenv.core.env_server.types import (
|
| 26 |
+
HealthResponse,
|
| 27 |
+
HealthStatus,
|
| 28 |
+
ResetRequest,
|
| 29 |
+
ResetResponse,
|
| 30 |
+
StepRequest,
|
| 31 |
+
StepResponse,
|
| 32 |
+
WSCloseMessage,
|
| 33 |
+
WSErrorCode,
|
| 34 |
+
WSErrorResponse,
|
| 35 |
+
WSObservationResponse,
|
| 36 |
+
WSResetMessage,
|
| 37 |
+
WSStateMessage,
|
| 38 |
+
WSStateResponse,
|
| 39 |
+
WSStepMessage,
|
| 40 |
+
)
|
| 41 |
+
from pydantic import ValidationError
|
| 42 |
+
|
| 43 |
+
from env_loader import load_env
|
| 44 |
+
from models import DataOpsAction, DataOpsObservation, DataOpsState
|
| 45 |
+
from server.dataops_env_environment import DataOpsEnvironment
|
| 46 |
+
from server.grading import evaluate_task
|
| 47 |
+
from server.session_manager import EnvironmentSessionManager
|
| 48 |
+
from server.task_specs import TASK_IDS, task_manifest_entries
|
| 49 |
+
|
| 50 |
+
# Repo root must be on sys.path (e.g. run `uv run python -m server.app` or uvicorn from project root).
|
| 51 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 52 |
+
SERVER_DIR = Path(__file__).resolve().parent
|
| 53 |
+
|
| 54 |
+
logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(name)s | %(message)s")
|
| 55 |
+
logger = logging.getLogger(__name__)
|
| 56 |
+
|
| 57 |
+
load_env()
|
| 58 |
+
|
| 59 |
+
SESSION_COOKIE_NAME = "dataops_session_id"
|
| 60 |
+
SESSION_HEADER_NAME = "X-Session-ID"
|
| 61 |
+
MAX_HTTP_SESSIONS = int(os.getenv("MAX_HTTP_SESSIONS", "128"))
|
| 62 |
+
HTTP_SESSION_TIMEOUT_S = float(os.getenv("HTTP_SESSION_TIMEOUT_S", "1200"))
|
| 63 |
+
MAX_WS_SESSIONS = max(1, int(os.getenv("MAX_WS_SESSIONS", "64")))
|
| 64 |
+
ADMIN_API_KEY = os.getenv("ADMIN_API_KEY", "").strip()
|
| 65 |
+
COOKIE_SECURE = os.getenv("COOKIE_SECURE", "").lower() in {"1", "true", "yes"}
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _public_grader_details_enabled() -> bool:
|
| 69 |
+
"""Read at request time so env / tests can control visibility without stale import-time state."""
|
| 70 |
+
v = os.getenv("PUBLIC_GRADER_DETAILS", "").strip().lower()
|
| 71 |
+
return v in {"1", "true", "yes"}
|
| 72 |
+
|
| 73 |
+
_ws_active_sessions = 0
|
| 74 |
+
_ws_session_lock = asyncio.Lock()
|
| 75 |
+
|
| 76 |
+
session_manager = EnvironmentSessionManager(
|
| 77 |
+
max_sessions=MAX_HTTP_SESSIONS,
|
| 78 |
+
session_timeout_s=HTTP_SESSION_TIMEOUT_S,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@asynccontextmanager
|
| 83 |
+
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
| 84 |
+
logger.info("DataOpsEnv starting.")
|
| 85 |
+
yield
|
| 86 |
+
session_manager.close_all()
|
| 87 |
+
logger.info("DataOpsEnv shutting down.")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
app = FastAPI(
|
| 91 |
+
title="DataOpsEnv",
|
| 92 |
+
description="Enterprise data pipeline remediation environment for training AI agents (OpenEnv-compliant).",
|
| 93 |
+
version="1.0.0",
|
| 94 |
+
lifespan=lifespan,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _cors_allow_origins() -> list[str]:
|
| 99 |
+
configured = os.getenv("CORS_ALLOW_ORIGINS", "").strip()
|
| 100 |
+
|
| 101 |
+
if not configured:
|
| 102 |
+
return []
|
| 103 |
+
|
| 104 |
+
if configured == "*":
|
| 105 |
+
return ["*"]
|
| 106 |
+
|
| 107 |
+
return [item.strip() for item in configured.split(",") if item.strip()]
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
app.add_middleware(
|
| 111 |
+
CORSMiddleware,
|
| 112 |
+
allow_origins=_cors_allow_origins(),
|
| 113 |
+
allow_methods=["*"],
|
| 114 |
+
allow_headers=["*"],
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _load_manifest() -> dict:
|
| 119 |
+
yaml_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "openenv.yaml")
|
| 120 |
+
try:
|
| 121 |
+
with open(yaml_path, encoding="utf-8") as f:
|
| 122 |
+
return yaml.safe_load(f) or {}
|
| 123 |
+
except FileNotFoundError:
|
| 124 |
+
return {}
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _load_yaml_tasks() -> list[dict]:
|
| 128 |
+
manifest = _load_manifest()
|
| 129 |
+
tasks = manifest.get("tasks")
|
| 130 |
+
if isinstance(tasks, list) and tasks:
|
| 131 |
+
manifest_ids = [str(item.get("id", "")) for item in tasks]
|
| 132 |
+
if manifest_ids == TASK_IDS:
|
| 133 |
+
return tasks
|
| 134 |
+
return task_manifest_entries()
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _wrap_obs(obs: DataOpsObservation) -> dict:
|
| 138 |
+
"""Serialise an observation to the standard OpenEnv response dict."""
|
| 139 |
+
return obs.model_dump()
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _get_session_id(request: Request) -> str | None:
|
| 143 |
+
header_value = request.headers.get(SESSION_HEADER_NAME)
|
| 144 |
+
if header_value:
|
| 145 |
+
return header_value.strip() or None
|
| 146 |
+
cookie_value = request.cookies.get(SESSION_COOKIE_NAME)
|
| 147 |
+
if cookie_value:
|
| 148 |
+
return cookie_value.strip() or None
|
| 149 |
+
return None
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _attach_session(response: Response, session_id: str) -> None:
|
| 153 |
+
response.set_cookie(
|
| 154 |
+
key=SESSION_COOKIE_NAME,
|
| 155 |
+
value=session_id,
|
| 156 |
+
httponly=True,
|
| 157 |
+
samesite="lax",
|
| 158 |
+
secure=COOKIE_SECURE,
|
| 159 |
+
max_age=int(HTTP_SESSION_TIMEOUT_S),
|
| 160 |
+
)
|
| 161 |
+
response.headers[SESSION_HEADER_NAME] = session_id
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _require_active_env(request: Request) -> tuple[str, DataOpsEnvironment]:
|
| 165 |
+
session_id, env = session_manager.get_session(_get_session_id(request))
|
| 166 |
+
if session_id is None or env is None:
|
| 167 |
+
raise HTTPException(400, "No active episode. Call /reset first.")
|
| 168 |
+
return session_id, env
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _ws_error_payload(message: str, code: WSErrorCode) -> str:
|
| 172 |
+
return WSErrorResponse(
|
| 173 |
+
data={
|
| 174 |
+
"message": message,
|
| 175 |
+
"code": code.value,
|
| 176 |
+
}
|
| 177 |
+
).model_dump_json()
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _require_admin(request: Request) -> None:
|
| 181 |
+
if not ADMIN_API_KEY:
|
| 182 |
+
return
|
| 183 |
+
if request.headers.get("X-Admin-Key", "") != ADMIN_API_KEY:
|
| 184 |
+
raise HTTPException(403, "Missing or invalid admin key.")
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def _request_is_admin(request: Request) -> bool:
|
| 188 |
+
return bool(ADMIN_API_KEY) and request.headers.get("X-Admin-Key", "") == ADMIN_API_KEY
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _format_grader_response(grade: dict, request: Request) -> dict:
|
| 192 |
+
if _public_grader_details_enabled() or _request_is_admin(request):
|
| 193 |
+
return grade
|
| 194 |
+
return {"task_id": grade.get("task_id"), "score": grade.get("score")}
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
async def _try_acquire_ws_slot() -> bool:
|
| 198 |
+
global _ws_active_sessions
|
| 199 |
+
|
| 200 |
+
async with _ws_session_lock:
|
| 201 |
+
if _ws_active_sessions >= MAX_WS_SESSIONS:
|
| 202 |
+
return False
|
| 203 |
+
_ws_active_sessions += 1
|
| 204 |
+
return True
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
async def _release_ws_slot() -> None:
|
| 208 |
+
global _ws_active_sessions
|
| 209 |
+
|
| 210 |
+
async with _ws_session_lock:
|
| 211 |
+
_ws_active_sessions = max(0, _ws_active_sessions - 1)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
@app.get("/health", response_model=HealthResponse)
|
| 215 |
+
def health_endpoint():
|
| 216 |
+
return HealthResponse(status=HealthStatus.HEALTHY)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
@app.get("/metadata")
|
| 220 |
+
def metadata_endpoint():
|
| 221 |
+
manifest = _load_manifest()
|
| 222 |
+
return {
|
| 223 |
+
"name": manifest.get("name", "dataops_env"),
|
| 224 |
+
"description": manifest.get(
|
| 225 |
+
"description",
|
| 226 |
+
(
|
| 227 |
+
"Enterprise data pipeline remediation environment. "
|
| 228 |
+
"Agents debug data streams, fix scripts, and send email reports."
|
| 229 |
+
),
|
| 230 |
+
),
|
| 231 |
+
"version": manifest.get("version", "1.0.0"),
|
| 232 |
+
"task_count": len(_load_yaml_tasks()),
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
@app.get("/schema")
|
| 237 |
+
def schema_endpoint():
|
| 238 |
+
return {
|
| 239 |
+
"action": DataOpsAction.model_json_schema(),
|
| 240 |
+
"observation": DataOpsObservation.model_json_schema(),
|
| 241 |
+
"state": DataOpsState.model_json_schema(),
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
@app.post("/mcp")
|
| 246 |
+
def mcp_endpoint(body: dict = Body(default_factory=dict)):
|
| 247 |
+
method = body.get("method", "")
|
| 248 |
+
req_id = body.get("id")
|
| 249 |
+
|
| 250 |
+
if method == "tools/list":
|
| 251 |
+
tools = [
|
| 252 |
+
{"name": atype, "description": f"Execute a {atype} action."}
|
| 253 |
+
for atype in [
|
| 254 |
+
"ExecuteSQL",
|
| 255 |
+
"ReadFile",
|
| 256 |
+
"WriteFile",
|
| 257 |
+
"RunScript",
|
| 258 |
+
"SendEmail",
|
| 259 |
+
]
|
| 260 |
+
]
|
| 261 |
+
return {"jsonrpc": "2.0", "id": req_id, "result": {"tools": tools}}
|
| 262 |
+
|
| 263 |
+
return {
|
| 264 |
+
"jsonrpc": "2.0",
|
| 265 |
+
"id": req_id,
|
| 266 |
+
"error": {"code": -32601, "message": "Method not found"},
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
@app.websocket("/ws")
|
| 271 |
+
async def websocket_endpoint(websocket: WebSocket):
|
| 272 |
+
await websocket.accept()
|
| 273 |
+
acquired_slot = await _try_acquire_ws_slot()
|
| 274 |
+
if not acquired_slot:
|
| 275 |
+
await websocket.send_text(
|
| 276 |
+
_ws_error_payload(
|
| 277 |
+
"WebSocket session capacity reached.",
|
| 278 |
+
WSErrorCode.CAPACITY_REACHED,
|
| 279 |
+
)
|
| 280 |
+
)
|
| 281 |
+
await websocket.close(code=1013)
|
| 282 |
+
return
|
| 283 |
+
|
| 284 |
+
env = DataOpsEnvironment()
|
| 285 |
+
|
| 286 |
+
try:
|
| 287 |
+
while True:
|
| 288 |
+
raw_message = await websocket.receive_text()
|
| 289 |
+
|
| 290 |
+
try:
|
| 291 |
+
message_dict = json.loads(raw_message)
|
| 292 |
+
except json.JSONDecodeError:
|
| 293 |
+
await websocket.send_text(
|
| 294 |
+
_ws_error_payload("Invalid JSON payload.", WSErrorCode.INVALID_JSON)
|
| 295 |
+
)
|
| 296 |
+
continue
|
| 297 |
+
|
| 298 |
+
message_type = message_dict.get("type", "")
|
| 299 |
+
|
| 300 |
+
try:
|
| 301 |
+
if message_type == "reset":
|
| 302 |
+
message = WSResetMessage(**message_dict)
|
| 303 |
+
observation = env.reset(**message.data)
|
| 304 |
+
response = WSObservationResponse(
|
| 305 |
+
data=serialize_observation(observation)
|
| 306 |
+
)
|
| 307 |
+
elif message_type == "step":
|
| 308 |
+
message = WSStepMessage(**message_dict)
|
| 309 |
+
action = DataOpsAction(**message.data)
|
| 310 |
+
observation = env.step(action)
|
| 311 |
+
response = WSObservationResponse(
|
| 312 |
+
data=serialize_observation(observation)
|
| 313 |
+
)
|
| 314 |
+
elif message_type == "state":
|
| 315 |
+
WSStateMessage(**message_dict)
|
| 316 |
+
response = WSStateResponse(data=env.state.model_dump())
|
| 317 |
+
elif message_type == "close":
|
| 318 |
+
WSCloseMessage(**message_dict)
|
| 319 |
+
break
|
| 320 |
+
else:
|
| 321 |
+
await websocket.send_text(
|
| 322 |
+
_ws_error_payload(
|
| 323 |
+
f"Unknown message type: {message_type}",
|
| 324 |
+
WSErrorCode.UNKNOWN_TYPE,
|
| 325 |
+
)
|
| 326 |
+
)
|
| 327 |
+
continue
|
| 328 |
+
|
| 329 |
+
await websocket.send_text(response.model_dump_json())
|
| 330 |
+
except ValidationError:
|
| 331 |
+
await websocket.send_text(
|
| 332 |
+
_ws_error_payload(
|
| 333 |
+
"Validation error while handling the WebSocket message.",
|
| 334 |
+
WSErrorCode.VALIDATION_ERROR,
|
| 335 |
+
)
|
| 336 |
+
)
|
| 337 |
+
except Exception:
|
| 338 |
+
logger.exception("WebSocket execution error")
|
| 339 |
+
await websocket.send_text(
|
| 340 |
+
_ws_error_payload(
|
| 341 |
+
"Execution error while handling the WebSocket message.",
|
| 342 |
+
WSErrorCode.EXECUTION_ERROR,
|
| 343 |
+
)
|
| 344 |
+
)
|
| 345 |
+
except WebSocketDisconnect:
|
| 346 |
+
logger.debug("WebSocket client disconnected.")
|
| 347 |
+
finally:
|
| 348 |
+
env.close()
|
| 349 |
+
await _release_ws_slot()
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
@app.post("/reset", response_model=ResetResponse)
|
| 353 |
+
def reset_endpoint(
|
| 354 |
+
request: Request,
|
| 355 |
+
response: Response,
|
| 356 |
+
task_id: str = Query("task_1_easy_anomaly", description="Task to initialise."),
|
| 357 |
+
body: ResetRequest = Body(default_factory=ResetRequest),
|
| 358 |
+
):
|
| 359 |
+
if task_id not in TASK_IDS:
|
| 360 |
+
raise HTTPException(400, f"Invalid task_id. Choose from: {TASK_IDS}")
|
| 361 |
+
session_id = _get_session_id(request)
|
| 362 |
+
resolved_session_id, _env, obs = session_manager.reset_session(
|
| 363 |
+
task_id=task_id,
|
| 364 |
+
seed=body.seed,
|
| 365 |
+
episode_id=body.episode_id,
|
| 366 |
+
session_id=session_id,
|
| 367 |
+
)
|
| 368 |
+
_attach_session(response, resolved_session_id)
|
| 369 |
+
return ResetResponse(observation=_wrap_obs(obs), reward=obs.reward, done=obs.done)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
@app.post("/step", response_model=StepResponse)
|
| 373 |
+
def step_endpoint(request: Request, response: Response, body: StepRequest):
|
| 374 |
+
try:
|
| 375 |
+
action = DataOpsAction(**body.action)
|
| 376 |
+
except ValidationError as e:
|
| 377 |
+
raise HTTPException(422, f"Invalid action: {e}") from e
|
| 378 |
+
session_id, env = _require_active_env(request)
|
| 379 |
+
_attach_session(response, session_id)
|
| 380 |
+
obs = env.step(action, timeout_s=body.timeout_s)
|
| 381 |
+
return StepResponse(observation=_wrap_obs(obs), reward=obs.reward, done=obs.done)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
@app.get("/state", response_model=DataOpsState)
|
| 385 |
+
def state_endpoint(request: Request, response: Response):
|
| 386 |
+
session_id, env = _require_active_env(request)
|
| 387 |
+
_attach_session(response, session_id)
|
| 388 |
+
return env.state
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
@app.get("/tasks")
|
| 392 |
+
def tasks_endpoint():
|
| 393 |
+
return {
|
| 394 |
+
"tasks": _load_yaml_tasks(),
|
| 395 |
+
"action_schema": DataOpsAction.model_json_schema(),
|
| 396 |
+
"observation_schema": DataOpsObservation.model_json_schema(),
|
| 397 |
+
"state_schema": DataOpsState.model_json_schema(),
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
@app.get("/grader")
|
| 402 |
+
def grader_current_endpoint(request: Request, response: Response):
|
| 403 |
+
"""Grade the current episode (uses active task_id from state)."""
|
| 404 |
+
session_id, env = _require_active_env(request)
|
| 405 |
+
_attach_session(response, session_id)
|
| 406 |
+
task_id = env.state.task_id
|
| 407 |
+
if not task_id:
|
| 408 |
+
raise HTTPException(400, "No active episode. Call /reset first.")
|
| 409 |
+
return _format_grader_response(evaluate_task(task_id, env), request)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
@app.get("/grader/{task_id}")
|
| 413 |
+
def grader_endpoint(task_id: str, request: Request, response: Response):
|
| 414 |
+
if task_id not in TASK_IDS:
|
| 415 |
+
raise HTTPException(404, f"Unknown task: {task_id}")
|
| 416 |
+
session_id, env = _require_active_env(request)
|
| 417 |
+
_attach_session(response, session_id)
|
| 418 |
+
active_task_id = env.state.task_id
|
| 419 |
+
if active_task_id and active_task_id != task_id:
|
| 420 |
+
raise HTTPException(
|
| 421 |
+
400,
|
| 422 |
+
f"Active episode belongs to task '{active_task_id}'. Reset the requested task first.",
|
| 423 |
+
)
|
| 424 |
+
return _format_grader_response(evaluate_task(task_id, env), request)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
@app.post("/baseline")
|
| 428 |
+
def baseline_endpoint(request: Request, body: dict = Body(default_factory=dict)):
|
| 429 |
+
"""Run inference.py (OpenAI tool-calling agent) against all tasks; same entrypoint as local baseline."""
|
| 430 |
+
_require_admin(request)
|
| 431 |
+
if not (
|
| 432 |
+
os.environ.get("API_KEY", "").strip()
|
| 433 |
+
or os.environ.get("HF_TOKEN", "").strip()
|
| 434 |
+
):
|
| 435 |
+
raise HTTPException(
|
| 436 |
+
503,
|
| 437 |
+
"API_KEY or HF_TOKEN must be set on the server process to run POST /baseline.",
|
| 438 |
+
)
|
| 439 |
+
script_path = PROJECT_ROOT / "inference.py"
|
| 440 |
+
if not script_path.is_file():
|
| 441 |
+
raise HTTPException(500, "inference.py missing from project root.")
|
| 442 |
+
|
| 443 |
+
port = int(os.getenv("PORT", "7860"))
|
| 444 |
+
timeout_s = HTTP_SESSION_TIMEOUT_S
|
| 445 |
+
|
| 446 |
+
env = {
|
| 447 |
+
**os.environ,
|
| 448 |
+
"ENV_BASE_URL": os.getenv("ENV_BASE_URL", f"http://127.0.0.1:{port}"),
|
| 449 |
+
}
|
| 450 |
+
command = [sys.executable, str(script_path), "--json-scores"]
|
| 451 |
+
if body.get("seed") is not None:
|
| 452 |
+
command.extend(["--seed", str(int(body["seed"]))])
|
| 453 |
+
if body.get("max_turns") is not None:
|
| 454 |
+
command.extend(["--max-turns", str(int(body["max_turns"]))])
|
| 455 |
+
for task_id in body.get("task_ids", []) or []:
|
| 456 |
+
if task_id in TASK_IDS:
|
| 457 |
+
command.extend(["--task", str(task_id)])
|
| 458 |
+
|
| 459 |
+
try:
|
| 460 |
+
proc = subprocess.run(
|
| 461 |
+
command,
|
| 462 |
+
cwd=str(PROJECT_ROOT),
|
| 463 |
+
capture_output=True,
|
| 464 |
+
text=True,
|
| 465 |
+
timeout=timeout_s,
|
| 466 |
+
env=env,
|
| 467 |
+
)
|
| 468 |
+
except subprocess.TimeoutExpired:
|
| 469 |
+
raise HTTPException(
|
| 470 |
+
504, f"Baseline exceeded HTTP_SESSION_TIMEOUT_S ({timeout_s}s)."
|
| 471 |
+
) from None
|
| 472 |
+
|
| 473 |
+
if proc.returncode != 0:
|
| 474 |
+
tail = (proc.stderr or proc.stdout or "")[-6000:]
|
| 475 |
+
logger.error("inference.py failed rc=%s stderr=%s", proc.returncode, tail[:500])
|
| 476 |
+
raise HTTPException(
|
| 477 |
+
502,
|
| 478 |
+
{"message": "inference.py exited with an error.", "detail": tail},
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
lines = [ln.strip() for ln in (proc.stdout or "").splitlines() if ln.strip()]
|
| 482 |
+
parsed = None
|
| 483 |
+
for line in reversed(lines):
|
| 484 |
+
try:
|
| 485 |
+
parsed = json.loads(line)
|
| 486 |
+
break
|
| 487 |
+
except json.JSONDecodeError:
|
| 488 |
+
continue
|
| 489 |
+
if not isinstance(parsed, dict) or "scores" not in parsed:
|
| 490 |
+
raise HTTPException(
|
| 491 |
+
502,
|
| 492 |
+
{
|
| 493 |
+
"message": "Could not parse JSON scores from inference.py stdout.",
|
| 494 |
+
"stdout_tail": "\n".join(lines[-5:]),
|
| 495 |
+
},
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
return {
|
| 499 |
+
"message": "Model baseline completed via inference.py.",
|
| 500 |
+
"stdout": proc.stdout,
|
| 501 |
+
"stderr": proc.stderr,
|
| 502 |
+
"scores": parsed["scores"],
|
| 503 |
+
"grades": parsed.get("grades"),
|
| 504 |
+
"average": parsed.get("average"),
|
| 505 |
+
"model": parsed.get("model"),
|
| 506 |
+
"metadata": parsed.get("metadata"),
|
| 507 |
+
}
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def main():
|
| 511 |
+
"""Entry point for `dataops-env` script and `openenv serve`."""
|
| 512 |
+
host = os.getenv("HOST", "0.0.0.0")
|
| 513 |
+
port = int(os.getenv("PORT", "7860"))
|
| 514 |
+
reload = os.getenv("DEBUG", "").lower() in ("1", "true")
|
| 515 |
+
cwd = Path.cwd().resolve()
|
| 516 |
+
app_target = "app:app" if cwd == SERVER_DIR else "server.app:app"
|
| 517 |
+
app_dir = str(SERVER_DIR if app_target == "app:app" else PROJECT_ROOT)
|
| 518 |
+
uvicorn.run(
|
| 519 |
+
app_target,
|
| 520 |
+
host=host,
|
| 521 |
+
port=port,
|
| 522 |
+
reload=reload,
|
| 523 |
+
reload_dirs=[str(PROJECT_ROOT)] if reload else None,
|
| 524 |
+
ws="wsproto",
|
| 525 |
+
app_dir=app_dir,
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
if __name__ == "__main__":
|
| 530 |
+
main()
|
server/dataops_env_environment.py
ADDED
|
@@ -0,0 +1,839 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import shutil
|
| 8 |
+
import sqlite3
|
| 9 |
+
import textwrap
|
| 10 |
+
import threading
|
| 11 |
+
import time
|
| 12 |
+
import uuid
|
| 13 |
+
from copy import deepcopy
|
| 14 |
+
from typing import Any, Optional
|
| 15 |
+
|
| 16 |
+
from openenv.core.env_server import Environment
|
| 17 |
+
from pydantic import ValidationError
|
| 18 |
+
|
| 19 |
+
from data.init_db import WORKSPACE_ROOT, setup_workspace
|
| 20 |
+
from models import (
|
| 21 |
+
PAYLOAD_MODELS,
|
| 22 |
+
DataOpsAction,
|
| 23 |
+
DataOpsObservation,
|
| 24 |
+
DataOpsState,
|
| 25 |
+
ExecuteSQLPayload,
|
| 26 |
+
ReadFilePayload,
|
| 27 |
+
RunScriptPayload,
|
| 28 |
+
SendEmailPayload,
|
| 29 |
+
WriteFilePayload,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
from .safe_exec import PythonRunResult, run_python_code, run_python_script
|
| 33 |
+
from .task_specs import (
|
| 34 |
+
TASK_ALLOWED_READ_FILES,
|
| 35 |
+
TASK_ALLOWED_RUN_FILES,
|
| 36 |
+
TASK_ALLOWED_WRITE_FILES,
|
| 37 |
+
TASK_EMAIL_ENABLED,
|
| 38 |
+
TASK_IDS,
|
| 39 |
+
TASK_SQL_POLICIES,
|
| 40 |
+
TaskScenarioBundle,
|
| 41 |
+
build_task_scenario,
|
| 42 |
+
normalize_task_3_rows,
|
| 43 |
+
report_matches_expected,
|
| 44 |
+
task_3_data_matches_expected,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
logger = logging.getLogger(__name__)
|
| 48 |
+
|
| 49 |
+
_SQL_COMMENT_RE = re.compile(r"(--[^\n]*|/\*.*?\*/)", re.DOTALL)
|
| 50 |
+
_SQL_STRING_RE = re.compile(r"'(?:''|[^'])*'|\"(?:\"\"|[^\"])*\"")
|
| 51 |
+
_SQL_TABLE_REF_RE = re.compile(
|
| 52 |
+
r"\b(?:from|join|update|into|delete\s+from)\s+([a-zA-Z_][a-zA-Z0-9_]*)",
|
| 53 |
+
re.IGNORECASE,
|
| 54 |
+
)
|
| 55 |
+
_SQL_CTE_NAME_RE = re.compile(
|
| 56 |
+
r"(?:\bwith\b|,)\s*([a-zA-Z_][a-zA-Z0-9_]*)\s+as\s*\(",
|
| 57 |
+
re.IGNORECASE,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
MAX_STEPS = 15
|
| 61 |
+
MAX_SQL_ROWS = 500
|
| 62 |
+
MAX_FILE_SIZE = 1_000_000
|
| 63 |
+
DEFAULT_ACTION_TIMEOUT_S = 10.0
|
| 64 |
+
MAX_ACTION_TIMEOUT_S = 30.0
|
| 65 |
+
MAX_STDOUT_CHARS = 50_000
|
| 66 |
+
MAX_STDERR_CHARS = 10_000
|
| 67 |
+
|
| 68 |
+
PENALTY_FAILURE = -0.03
|
| 69 |
+
PENALTY_DESTRUCTIVE = -0.20
|
| 70 |
+
PENALTY_REPEAT = -0.08
|
| 71 |
+
PENALTY_DISALLOWED_TOOL_UNIT = -0.04
|
| 72 |
+
|
| 73 |
+
# Keep milestone bonuses small so the terminal grader remains the dominant signal.
|
| 74 |
+
REWARD_EVENT_VALUES = {
|
| 75 |
+
"t1_inspected_corruption": 0.05,
|
| 76 |
+
"t1_exact_cleanup": 0.04,
|
| 77 |
+
"t2_read_source": 0.04,
|
| 78 |
+
"t2_candidate_compiles": 0.02,
|
| 79 |
+
"t2_verified_fix": 0.03,
|
| 80 |
+
"t3_nonempty_select": 0.03,
|
| 81 |
+
"t3_matching_sql": 0.03,
|
| 82 |
+
"t3_read_formatter_source": 0.02,
|
| 83 |
+
"t3_report_data_verified": 0.03,
|
| 84 |
+
"t3_formatter_compiles": 0.02,
|
| 85 |
+
"t3_report_generated": 0.03,
|
| 86 |
+
"t3_email_verified": 0.02,
|
| 87 |
+
}
|
| 88 |
+
PENALTY_EVENTS = {
|
| 89 |
+
"destructive_sql": PENALTY_DESTRUCTIVE,
|
| 90 |
+
"multiple_emails": -0.08,
|
| 91 |
+
"t2_run_before_read": -0.05,
|
| 92 |
+
"t2_write_before_read": -0.05,
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class DataOpsEnvironment(Environment[DataOpsAction, DataOpsObservation, DataOpsState]):
|
| 97 |
+
"""Enterprise data pipeline remediation environment (OpenEnv-compliant)."""
|
| 98 |
+
|
| 99 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 100 |
+
|
| 101 |
+
def __init__(self) -> None:
|
| 102 |
+
self._workspace_dir = os.path.join(WORKSPACE_ROOT, "sessions", uuid.uuid4().hex)
|
| 103 |
+
self._db_path = os.path.join(self._workspace_dir, "mock_warehouse.db")
|
| 104 |
+
self._state = DataOpsState()
|
| 105 |
+
self._scenario: TaskScenarioBundle = build_task_scenario(
|
| 106 |
+
"task_1_easy_anomaly", seed=0
|
| 107 |
+
)
|
| 108 |
+
self._evidence: dict[str, Any] = {}
|
| 109 |
+
self._pending_events: list[str] = []
|
| 110 |
+
self.email_outbox: list[dict[str, str]] = []
|
| 111 |
+
self._last_action_key: Optional[str] = None
|
| 112 |
+
self._milestones: set[str] = set()
|
| 113 |
+
self._grader_score = 0.0
|
| 114 |
+
self._disallowed_tool_attempts = 0
|
| 115 |
+
self._lock = threading.Lock()
|
| 116 |
+
|
| 117 |
+
def reset(
|
| 118 |
+
self,
|
| 119 |
+
seed: Optional[int] = None,
|
| 120 |
+
episode_id: Optional[str] = None,
|
| 121 |
+
**kwargs: Any,
|
| 122 |
+
) -> DataOpsObservation:
|
| 123 |
+
task_id: str = kwargs.get("task_id", "task_1_easy_anomaly")
|
| 124 |
+
if task_id not in TASK_IDS:
|
| 125 |
+
raise ValueError(f"Unknown task_id: {task_id}")
|
| 126 |
+
|
| 127 |
+
with self._lock:
|
| 128 |
+
self._scenario = build_task_scenario(task_id, seed=seed)
|
| 129 |
+
self._db_path = setup_workspace(
|
| 130 |
+
self._workspace_dir,
|
| 131 |
+
scenario=self._scenario,
|
| 132 |
+
)
|
| 133 |
+
self.email_outbox.clear()
|
| 134 |
+
self._last_action_key = None
|
| 135 |
+
self._milestones.clear()
|
| 136 |
+
self._pending_events = []
|
| 137 |
+
self._disallowed_tool_attempts = 0
|
| 138 |
+
self._evidence = self._initial_evidence()
|
| 139 |
+
self._state = DataOpsState(
|
| 140 |
+
episode_id=episode_id or str(uuid.uuid4()),
|
| 141 |
+
step_count=0,
|
| 142 |
+
task_id=task_id,
|
| 143 |
+
task_description=self._scenario.description,
|
| 144 |
+
max_steps=MAX_STEPS,
|
| 145 |
+
seed=self._scenario.seed,
|
| 146 |
+
)
|
| 147 |
+
self._grader_score = self._current_task_score()
|
| 148 |
+
|
| 149 |
+
return DataOpsObservation(
|
| 150 |
+
status="success",
|
| 151 |
+
done=False,
|
| 152 |
+
reward=0.0,
|
| 153 |
+
message=f"Environment reset. Task: {self._scenario.description}",
|
| 154 |
+
step_count=0,
|
| 155 |
+
max_steps=MAX_STEPS,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
def step(
|
| 159 |
+
self, action: DataOpsAction, timeout_s: Optional[float] = None, **kwargs: Any
|
| 160 |
+
) -> DataOpsObservation:
|
| 161 |
+
del kwargs
|
| 162 |
+
with self._lock:
|
| 163 |
+
return self._step_locked(action, timeout_s)
|
| 164 |
+
|
| 165 |
+
@property
|
| 166 |
+
def state(self) -> DataOpsState:
|
| 167 |
+
return self._state.model_copy()
|
| 168 |
+
|
| 169 |
+
@property
|
| 170 |
+
def scenario(self) -> TaskScenarioBundle:
|
| 171 |
+
return self._scenario
|
| 172 |
+
|
| 173 |
+
@property
|
| 174 |
+
def evidence(self) -> dict[str, Any]:
|
| 175 |
+
return deepcopy(self._evidence)
|
| 176 |
+
|
| 177 |
+
@property
|
| 178 |
+
def workspace_dir(self) -> str:
|
| 179 |
+
return self._workspace_dir
|
| 180 |
+
|
| 181 |
+
@property
|
| 182 |
+
def db_path(self) -> str:
|
| 183 |
+
return self._db_path
|
| 184 |
+
|
| 185 |
+
def close(self) -> None:
|
| 186 |
+
if os.path.isdir(self._workspace_dir):
|
| 187 |
+
shutil.rmtree(self._workspace_dir, ignore_errors=True)
|
| 188 |
+
|
| 189 |
+
def _step_locked(
|
| 190 |
+
self, action: DataOpsAction, timeout_s: Optional[float]
|
| 191 |
+
) -> DataOpsObservation:
|
| 192 |
+
if self._state.done:
|
| 193 |
+
return self._obs(
|
| 194 |
+
"error", "Episode is over. Call /reset to start a new one.", done=True
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
model_cls = PAYLOAD_MODELS.get(action.action_type)
|
| 198 |
+
if not model_cls:
|
| 199 |
+
return self._obs("error", f"Unknown action_type: {action.action_type}")
|
| 200 |
+
|
| 201 |
+
try:
|
| 202 |
+
payload = model_cls(**action.payload)
|
| 203 |
+
except ValidationError as exc:
|
| 204 |
+
return self._obs(
|
| 205 |
+
"error",
|
| 206 |
+
f"Invalid payload: {exc.error_count()} validation error(s).",
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
self._pending_events = []
|
| 210 |
+
obs = self._dispatch(action.action_type, payload, timeout_s)
|
| 211 |
+
|
| 212 |
+
reward = self._compute_reward(action, obs)
|
| 213 |
+
self._state.step_count += 1
|
| 214 |
+
self._state.cumulative_reward += reward
|
| 215 |
+
self._state.actions_taken.append(action.action_type)
|
| 216 |
+
self._state.emails_sent = len(self.email_outbox)
|
| 217 |
+
|
| 218 |
+
done = self._state.step_count >= MAX_STEPS or self._task_completed()
|
| 219 |
+
self._state.done = done
|
| 220 |
+
|
| 221 |
+
obs.reward = round(reward, 4)
|
| 222 |
+
obs.done = done
|
| 223 |
+
obs.step_count = self._state.step_count
|
| 224 |
+
obs.max_steps = MAX_STEPS
|
| 225 |
+
return obs
|
| 226 |
+
|
| 227 |
+
def _dispatch(
|
| 228 |
+
self, action_type: str, payload: Any, timeout_s: Optional[float]
|
| 229 |
+
) -> DataOpsObservation:
|
| 230 |
+
handlers = {
|
| 231 |
+
"ExecuteSQL": self._handle_sql,
|
| 232 |
+
"ReadFile": self._handle_read,
|
| 233 |
+
"WriteFile": self._handle_write,
|
| 234 |
+
"RunScript": self._handle_run,
|
| 235 |
+
"SendEmail": self._handle_email,
|
| 236 |
+
}
|
| 237 |
+
return handlers[action_type](payload, timeout_s)
|
| 238 |
+
|
| 239 |
+
def _handle_sql(
|
| 240 |
+
self, payload: ExecuteSQLPayload, timeout_s: Optional[float]
|
| 241 |
+
) -> DataOpsObservation:
|
| 242 |
+
query = payload.query.strip()
|
| 243 |
+
while True:
|
| 244 |
+
q = query.rstrip()
|
| 245 |
+
if not q.endswith(";"):
|
| 246 |
+
break
|
| 247 |
+
query = q[:-1].rstrip()
|
| 248 |
+
statement_type = self._statement_type(query)
|
| 249 |
+
validation_error = self._validate_sql_action(query, statement_type)
|
| 250 |
+
if validation_error:
|
| 251 |
+
return self._obs("error", validation_error)
|
| 252 |
+
|
| 253 |
+
timeout = self._resolve_timeout(timeout_s)
|
| 254 |
+
deadline = time.monotonic() + timeout
|
| 255 |
+
|
| 256 |
+
try:
|
| 257 |
+
with sqlite3.connect(self._db_path) as conn:
|
| 258 |
+
conn.set_progress_handler(
|
| 259 |
+
lambda: 1 if time.monotonic() >= deadline else 0,
|
| 260 |
+
1_000,
|
| 261 |
+
)
|
| 262 |
+
conn.row_factory = sqlite3.Row
|
| 263 |
+
cursor = conn.cursor()
|
| 264 |
+
cursor.execute(query)
|
| 265 |
+
|
| 266 |
+
if statement_type in {"SELECT", "WITH"}:
|
| 267 |
+
cols = [c[0] for c in cursor.description or []]
|
| 268 |
+
rows_raw = cursor.fetchmany(MAX_SQL_ROWS + 1)
|
| 269 |
+
if len(rows_raw) > MAX_SQL_ROWS:
|
| 270 |
+
return self._obs(
|
| 271 |
+
"error",
|
| 272 |
+
f"Result exceeds {MAX_SQL_ROWS} rows. Add a LIMIT clause.",
|
| 273 |
+
)
|
| 274 |
+
rows = [dict(zip(cols, row)) for row in rows_raw]
|
| 275 |
+
self._record_sql_select(query, rows)
|
| 276 |
+
return DataOpsObservation(
|
| 277 |
+
status="success",
|
| 278 |
+
sql_results=rows,
|
| 279 |
+
message=f"Query returned {len(rows)} rows.",
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
conn.commit()
|
| 283 |
+
self._record_sql_mutation(query, cursor.rowcount)
|
| 284 |
+
return self._obs("success", f"Rows affected: {cursor.rowcount}")
|
| 285 |
+
except sqlite3.Error as exc:
|
| 286 |
+
if "interrupted" in str(exc).lower():
|
| 287 |
+
return self._obs(
|
| 288 |
+
"error", f"SQL execution timed out ({timeout:.1f}s limit)."
|
| 289 |
+
)
|
| 290 |
+
logger.warning("SQL error: %s", exc)
|
| 291 |
+
msg = "SQL execution error. Check your query syntax."
|
| 292 |
+
if self._state.task_id == "task_3_hard_e2e" and re.search(
|
| 293 |
+
r"\bdate\b", query, re.IGNORECASE
|
| 294 |
+
):
|
| 295 |
+
if "report_date" not in query.lower():
|
| 296 |
+
msg += " Hint: table `daily_reports` uses column `report_date` for the calendar date."
|
| 297 |
+
return self._obs("error", msg)
|
| 298 |
+
|
| 299 |
+
def _handle_read(
|
| 300 |
+
self, payload: ReadFilePayload, timeout_s: Optional[float]
|
| 301 |
+
) -> DataOpsObservation:
|
| 302 |
+
del timeout_s
|
| 303 |
+
basename = os.path.basename(payload.filepath)
|
| 304 |
+
if not self._is_allowed_file(TASK_ALLOWED_READ_FILES, basename):
|
| 305 |
+
return self._obs(
|
| 306 |
+
"error", f"Reading {basename} is not allowed for this task."
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
safe_path = self._resolve_workspace_path(basename)
|
| 310 |
+
if safe_path is None:
|
| 311 |
+
return self._obs("error", "Resolved file path escapes the workspace.")
|
| 312 |
+
if not os.path.isfile(safe_path):
|
| 313 |
+
return self._obs("error", f"File not found: {basename}")
|
| 314 |
+
if os.path.getsize(safe_path) > MAX_FILE_SIZE:
|
| 315 |
+
return self._obs("error", "File too large to read.")
|
| 316 |
+
|
| 317 |
+
try:
|
| 318 |
+
with open(safe_path, encoding="utf-8") as f:
|
| 319 |
+
content = f.read(MAX_FILE_SIZE)
|
| 320 |
+
except OSError:
|
| 321 |
+
return self._obs("error", "Failed to read file.")
|
| 322 |
+
|
| 323 |
+
if (
|
| 324 |
+
self._state.task_id == "task_2_medium_syntax"
|
| 325 |
+
and basename == "broken_pipeline.py"
|
| 326 |
+
):
|
| 327 |
+
self._evidence["task_2"]["read_source"] = True
|
| 328 |
+
self._record_event("t2_read_source")
|
| 329 |
+
if self._state.task_id == "task_3_hard_e2e" and basename == "format_report.py":
|
| 330 |
+
self._evidence["task_3"]["read_formatter_source"] = True
|
| 331 |
+
self._record_event("t3_read_formatter_source")
|
| 332 |
+
return DataOpsObservation(
|
| 333 |
+
status="success",
|
| 334 |
+
stdout=content,
|
| 335 |
+
message=f"Read {len(content)} chars from {basename}",
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
def _handle_write(
|
| 339 |
+
self, payload: WriteFilePayload, timeout_s: Optional[float]
|
| 340 |
+
) -> DataOpsObservation:
|
| 341 |
+
del timeout_s
|
| 342 |
+
basename = os.path.basename(payload.filepath)
|
| 343 |
+
if not self._is_allowed_file(TASK_ALLOWED_WRITE_FILES, basename):
|
| 344 |
+
return self._obs(
|
| 345 |
+
"error", f"Writing {basename} is not allowed for this task."
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
if (
|
| 349 |
+
self._state.task_id == "task_2_medium_syntax"
|
| 350 |
+
and basename == "broken_pipeline.py"
|
| 351 |
+
):
|
| 352 |
+
if not self._evidence["task_2"]["read_source"]:
|
| 353 |
+
self._pending_events.append("t2_write_before_read")
|
| 354 |
+
|
| 355 |
+
safe_path = self._resolve_workspace_path(basename)
|
| 356 |
+
if safe_path is None:
|
| 357 |
+
return self._obs("error", "Resolved file path escapes the workspace.")
|
| 358 |
+
|
| 359 |
+
try:
|
| 360 |
+
with open(safe_path, "w", encoding="utf-8") as f:
|
| 361 |
+
f.write(payload.content)
|
| 362 |
+
except OSError:
|
| 363 |
+
return self._obs("error", "Failed to write file.")
|
| 364 |
+
|
| 365 |
+
self._record_write_evidence(basename, payload.content)
|
| 366 |
+
return self._obs("success", f"Wrote {len(payload.content)} chars to {basename}")
|
| 367 |
+
|
| 368 |
+
def _handle_run(
|
| 369 |
+
self, payload: RunScriptPayload, timeout_s: Optional[float]
|
| 370 |
+
) -> DataOpsObservation:
|
| 371 |
+
basename = os.path.basename(payload.filepath)
|
| 372 |
+
if not self._is_allowed_file(TASK_ALLOWED_RUN_FILES, basename):
|
| 373 |
+
return self._obs(
|
| 374 |
+
"error", f"Executing {basename} is not allowed for this task."
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
script_path = self._resolve_workspace_path(basename)
|
| 378 |
+
if script_path is None:
|
| 379 |
+
return self._obs("error", "Resolved script path escapes the workspace.")
|
| 380 |
+
if not os.path.isfile(script_path):
|
| 381 |
+
return self._obs("error", f"Script not found: {basename}")
|
| 382 |
+
|
| 383 |
+
if (
|
| 384 |
+
self._state.task_id == "task_2_medium_syntax"
|
| 385 |
+
and basename == "broken_pipeline.py"
|
| 386 |
+
):
|
| 387 |
+
if not self._evidence["task_2"]["read_source"]:
|
| 388 |
+
self._pending_events.append("t2_run_before_read")
|
| 389 |
+
|
| 390 |
+
timeout = self._resolve_timeout(timeout_s)
|
| 391 |
+
try:
|
| 392 |
+
result = run_python_script(
|
| 393 |
+
basename,
|
| 394 |
+
cwd=self._workspace_dir,
|
| 395 |
+
args=list(payload.args),
|
| 396 |
+
timeout_s=timeout,
|
| 397 |
+
stdout_limit=MAX_STDOUT_CHARS,
|
| 398 |
+
stderr_limit=MAX_STDERR_CHARS,
|
| 399 |
+
)
|
| 400 |
+
except OSError:
|
| 401 |
+
return self._obs("error", "Failed to execute script.")
|
| 402 |
+
|
| 403 |
+
if result.timed_out:
|
| 404 |
+
return self._obs("error", f"Script timed out ({timeout:.1f}s limit).")
|
| 405 |
+
|
| 406 |
+
self._record_run_evidence(basename, payload.args, result)
|
| 407 |
+
status = "success" if result.returncode == 0 else "error"
|
| 408 |
+
return DataOpsObservation(
|
| 409 |
+
status=status,
|
| 410 |
+
stdout=(result.stdout or "")[:MAX_STDOUT_CHARS],
|
| 411 |
+
stderr=(result.stderr or "")[:MAX_STDERR_CHARS],
|
| 412 |
+
message=f"Exit code: {result.returncode}",
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
def _handle_email(
|
| 416 |
+
self, payload: SendEmailPayload, timeout_s: Optional[float]
|
| 417 |
+
) -> DataOpsObservation:
|
| 418 |
+
del timeout_s
|
| 419 |
+
if self._state.task_id not in TASK_EMAIL_ENABLED:
|
| 420 |
+
self._disallowed_tool_attempts += 1
|
| 421 |
+
self._pending_events.append("disallowed_tool")
|
| 422 |
+
return self._obs(
|
| 423 |
+
"error",
|
| 424 |
+
"Email is not available for this task. Use read_file, write_file, and invoke_python only.",
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
email = {
|
| 428 |
+
"to_email": payload.to_email,
|
| 429 |
+
"subject": payload.subject,
|
| 430 |
+
"body": payload.body,
|
| 431 |
+
}
|
| 432 |
+
self.email_outbox.append(email)
|
| 433 |
+
self._record_email_evidence(email)
|
| 434 |
+
return DataOpsObservation(
|
| 435 |
+
status="success",
|
| 436 |
+
email_delivery_status=f"Queued for {payload.to_email}",
|
| 437 |
+
message=f"Email queued for delivery to {payload.to_email}",
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
def _compute_reward(self, action: DataOpsAction, obs: DataOpsObservation) -> float:
|
| 441 |
+
current_score = self._current_task_score()
|
| 442 |
+
reward = current_score - self._grader_score
|
| 443 |
+
self._grader_score = current_score
|
| 444 |
+
|
| 445 |
+
if obs.status != "success":
|
| 446 |
+
reward += PENALTY_FAILURE
|
| 447 |
+
|
| 448 |
+
action_key = (
|
| 449 |
+
f"{action.action_type}:"
|
| 450 |
+
f"{json.dumps(action.payload, sort_keys=True, ensure_ascii=True)}"
|
| 451 |
+
)
|
| 452 |
+
if action_key == self._last_action_key:
|
| 453 |
+
reward += PENALTY_REPEAT
|
| 454 |
+
self._last_action_key = action_key
|
| 455 |
+
|
| 456 |
+
for event in self._pending_events:
|
| 457 |
+
if event == "disallowed_tool":
|
| 458 |
+
reward += PENALTY_DISALLOWED_TOOL_UNIT * min(
|
| 459 |
+
self._disallowed_tool_attempts, 12
|
| 460 |
+
)
|
| 461 |
+
continue
|
| 462 |
+
if event in PENALTY_EVENTS:
|
| 463 |
+
reward += PENALTY_EVENTS[event]
|
| 464 |
+
continue
|
| 465 |
+
reward += self._award_event(event)
|
| 466 |
+
return reward
|
| 467 |
+
|
| 468 |
+
def _award_event(self, event: str) -> float:
|
| 469 |
+
if event in self._milestones:
|
| 470 |
+
return 0.0
|
| 471 |
+
self._milestones.add(event)
|
| 472 |
+
return REWARD_EVENT_VALUES.get(event, 0.0)
|
| 473 |
+
|
| 474 |
+
def _initial_evidence(self) -> dict[str, Any]:
|
| 475 |
+
return {
|
| 476 |
+
"task_1": {
|
| 477 |
+
"inspected_corrupted_rows": False,
|
| 478 |
+
"exact_cleanup": False,
|
| 479 |
+
"destructive_sql_attempted": False,
|
| 480 |
+
},
|
| 481 |
+
"task_2": {
|
| 482 |
+
"read_source": False,
|
| 483 |
+
"candidate_compiles": False,
|
| 484 |
+
"verified_fix": False,
|
| 485 |
+
},
|
| 486 |
+
"task_3": {
|
| 487 |
+
"matching_sql_executed": False,
|
| 488 |
+
"last_matching_sql_rows": [],
|
| 489 |
+
"read_formatter_source": False,
|
| 490 |
+
"report_data_matches_sql": False,
|
| 491 |
+
"formatter_compiles": False,
|
| 492 |
+
"format_output_matches_expected": False,
|
| 493 |
+
"last_formatter_output": "",
|
| 494 |
+
"email_matches_formatter_output": False,
|
| 495 |
+
"single_email_sent": True,
|
| 496 |
+
},
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
def _record_sql_select(self, query: str, rows: list[dict[str, Any]]) -> None:
|
| 500 |
+
if self._scenario.task_1 and self._state.task_id == "task_1_easy_anomaly":
|
| 501 |
+
row_ids = {int(row.get("id")) for row in rows if row.get("id") is not None}
|
| 502 |
+
corrupted = set(self._scenario.task_1.corrupted_row_ids)
|
| 503 |
+
if row_ids & corrupted:
|
| 504 |
+
self._evidence["task_1"]["inspected_corrupted_rows"] = True
|
| 505 |
+
self._record_event("t1_inspected_corruption")
|
| 506 |
+
|
| 507 |
+
if self._scenario.task_3 and self._state.task_id == "task_3_hard_e2e":
|
| 508 |
+
normalised_rows = normalize_task_3_rows(rows, require_headcount=True)
|
| 509 |
+
expected_rows = list(self._scenario.task_3.expected_rows)
|
| 510 |
+
if task_3_data_matches_expected(
|
| 511 |
+
normalised_rows,
|
| 512 |
+
expected_rows,
|
| 513 |
+
require_headcount=True,
|
| 514 |
+
):
|
| 515 |
+
self._evidence["task_3"]["matching_sql_executed"] = True
|
| 516 |
+
self._evidence["task_3"]["last_matching_sql_rows"] = normalised_rows
|
| 517 |
+
self._record_event("t3_matching_sql")
|
| 518 |
+
elif rows:
|
| 519 |
+
self._record_event("t3_nonempty_select")
|
| 520 |
+
|
| 521 |
+
def _record_sql_mutation(self, query: str, rowcount: int) -> None:
|
| 522 |
+
del rowcount
|
| 523 |
+
if self._scenario.task_1 and self._state.task_id == "task_1_easy_anomaly":
|
| 524 |
+
exact_rows = self._current_transactions_rows()
|
| 525 |
+
expected_rows = list(self._scenario.task_1.expected_rows)
|
| 526 |
+
expected_by_id = {row["id"]: row for row in expected_rows}
|
| 527 |
+
actual_by_id = {row["id"]: row for row in exact_rows}
|
| 528 |
+
valid_rows_lost = any(
|
| 529 |
+
row_id not in actual_by_id for row_id in expected_by_id
|
| 530 |
+
)
|
| 531 |
+
valid_rows_changed = any(
|
| 532 |
+
actual_by_id[row_id] != expected_row
|
| 533 |
+
for row_id, expected_row in expected_by_id.items()
|
| 534 |
+
if row_id in actual_by_id
|
| 535 |
+
)
|
| 536 |
+
if exact_rows == expected_rows:
|
| 537 |
+
self._evidence["task_1"]["exact_cleanup"] = True
|
| 538 |
+
self._record_event("t1_exact_cleanup")
|
| 539 |
+
elif valid_rows_lost or valid_rows_changed:
|
| 540 |
+
self._evidence["task_1"]["destructive_sql_attempted"] = True
|
| 541 |
+
self._pending_events.append("destructive_sql")
|
| 542 |
+
|
| 543 |
+
def _record_write_evidence(self, basename: str, content: str) -> None:
|
| 544 |
+
if (
|
| 545 |
+
self._state.task_id == "task_2_medium_syntax"
|
| 546 |
+
and basename == "broken_pipeline.py"
|
| 547 |
+
):
|
| 548 |
+
compiles = self._script_compiles(content, basename)
|
| 549 |
+
self._evidence["task_2"]["candidate_compiles"] = compiles
|
| 550 |
+
if compiles:
|
| 551 |
+
self._record_event("t2_candidate_compiles")
|
| 552 |
+
return
|
| 553 |
+
|
| 554 |
+
if not self._scenario.task_3 or self._state.task_id != "task_3_hard_e2e":
|
| 555 |
+
return
|
| 556 |
+
|
| 557 |
+
task_3 = self._evidence["task_3"]
|
| 558 |
+
if basename == "report_data.json":
|
| 559 |
+
try:
|
| 560 |
+
payload = json.loads(content)
|
| 561 |
+
except json.JSONDecodeError:
|
| 562 |
+
task_3["report_data_matches_sql"] = False
|
| 563 |
+
return
|
| 564 |
+
if not isinstance(payload, list):
|
| 565 |
+
task_3["report_data_matches_sql"] = False
|
| 566 |
+
return
|
| 567 |
+
normalised_rows = normalize_task_3_rows(payload, require_headcount=True)
|
| 568 |
+
expected_rows = list(self._scenario.task_3.expected_rows)
|
| 569 |
+
last_sql_rows = task_3.get("last_matching_sql_rows", [])
|
| 570 |
+
matches_sql = bool(last_sql_rows) and normalised_rows == last_sql_rows
|
| 571 |
+
matches_expected = task_3_data_matches_expected(
|
| 572 |
+
normalised_rows,
|
| 573 |
+
expected_rows,
|
| 574 |
+
require_headcount=True,
|
| 575 |
+
)
|
| 576 |
+
task_3["report_data_matches_sql"] = matches_sql and matches_expected
|
| 577 |
+
if task_3["report_data_matches_sql"]:
|
| 578 |
+
self._record_event("t3_report_data_verified")
|
| 579 |
+
return
|
| 580 |
+
|
| 581 |
+
if basename == "format_report.py":
|
| 582 |
+
compiles = self._script_compiles(content, basename)
|
| 583 |
+
task_3["formatter_compiles"] = compiles
|
| 584 |
+
if compiles:
|
| 585 |
+
self._record_event("t3_formatter_compiles")
|
| 586 |
+
|
| 587 |
+
def _record_run_evidence(
|
| 588 |
+
self,
|
| 589 |
+
basename: str,
|
| 590 |
+
args: list[str],
|
| 591 |
+
result: PythonRunResult,
|
| 592 |
+
) -> None:
|
| 593 |
+
if (
|
| 594 |
+
self._state.task_id == "task_2_medium_syntax"
|
| 595 |
+
and basename == "broken_pipeline.py"
|
| 596 |
+
):
|
| 597 |
+
if result.returncode == 0 and self._task_2_candidate_is_functional():
|
| 598 |
+
self._evidence["task_2"]["verified_fix"] = True
|
| 599 |
+
self._record_event("t2_verified_fix")
|
| 600 |
+
return
|
| 601 |
+
|
| 602 |
+
if not self._scenario.task_3 or self._state.task_id != "task_3_hard_e2e":
|
| 603 |
+
return
|
| 604 |
+
if basename != "format_report.py":
|
| 605 |
+
return
|
| 606 |
+
|
| 607 |
+
task_3 = self._evidence["task_3"]
|
| 608 |
+
stdout = (result.stdout or "").strip()
|
| 609 |
+
if (
|
| 610 |
+
result.returncode == 0
|
| 611 |
+
and self._task_3_args_reference_report_data(args)
|
| 612 |
+
and task_3.get("report_data_matches_sql")
|
| 613 |
+
and report_matches_expected(
|
| 614 |
+
stdout,
|
| 615 |
+
self._scenario.task_3.expected_rows,
|
| 616 |
+
self._scenario.task_3.target_date,
|
| 617 |
+
)
|
| 618 |
+
):
|
| 619 |
+
task_3["format_output_matches_expected"] = True
|
| 620 |
+
task_3["last_formatter_output"] = stdout
|
| 621 |
+
self._record_event("t3_report_generated")
|
| 622 |
+
|
| 623 |
+
def _record_email_evidence(self, email: dict[str, str]) -> None:
|
| 624 |
+
if not self._scenario.task_3 or self._state.task_id != "task_3_hard_e2e":
|
| 625 |
+
return
|
| 626 |
+
|
| 627 |
+
task_3 = self._evidence["task_3"]
|
| 628 |
+
if len(self.email_outbox) > 1:
|
| 629 |
+
task_3["single_email_sent"] = False
|
| 630 |
+
self._pending_events.append("multiple_emails")
|
| 631 |
+
|
| 632 |
+
if (
|
| 633 |
+
task_3.get("format_output_matches_expected")
|
| 634 |
+
and task_3.get("single_email_sent")
|
| 635 |
+
and email.get("to_email") == self._scenario.task_3.recipient
|
| 636 |
+
and email.get("subject") == self._scenario.task_3.subject
|
| 637 |
+
and email.get("body", "").strip()
|
| 638 |
+
== str(task_3.get("last_formatter_output", "")).strip()
|
| 639 |
+
):
|
| 640 |
+
task_3["email_matches_formatter_output"] = True
|
| 641 |
+
self._record_event("t3_email_verified")
|
| 642 |
+
|
| 643 |
+
def _task_2_candidate_is_functional(self) -> bool:
|
| 644 |
+
if not self._scenario.task_2:
|
| 645 |
+
return False
|
| 646 |
+
wrapper = textwrap.dedent(
|
| 647 |
+
f"""
|
| 648 |
+
import importlib.util
|
| 649 |
+
import json
|
| 650 |
+
|
| 651 |
+
spec = importlib.util.spec_from_file_location("candidate_pipeline", "broken_pipeline.py")
|
| 652 |
+
module = importlib.util.module_from_spec(spec)
|
| 653 |
+
assert spec.loader is not None
|
| 654 |
+
spec.loader.exec_module(module)
|
| 655 |
+
|
| 656 |
+
cases = {json.dumps(self._scenario.task_2.hidden_cases)}
|
| 657 |
+
results = [module.process_data_stream(case) for case in cases]
|
| 658 |
+
print("__RESULT__=" + json.dumps(results))
|
| 659 |
+
"""
|
| 660 |
+
)
|
| 661 |
+
try:
|
| 662 |
+
result = run_python_code(
|
| 663 |
+
wrapper,
|
| 664 |
+
cwd=self._workspace_dir,
|
| 665 |
+
timeout_s=DEFAULT_ACTION_TIMEOUT_S,
|
| 666 |
+
stdout_limit=MAX_STDOUT_CHARS,
|
| 667 |
+
stderr_limit=MAX_STDERR_CHARS,
|
| 668 |
+
)
|
| 669 |
+
except Exception:
|
| 670 |
+
return False
|
| 671 |
+
payload = next(
|
| 672 |
+
(
|
| 673 |
+
line[len("__RESULT__=") :]
|
| 674 |
+
for line in result.stdout.splitlines()
|
| 675 |
+
if line.startswith("__RESULT__=")
|
| 676 |
+
),
|
| 677 |
+
"",
|
| 678 |
+
)
|
| 679 |
+
try:
|
| 680 |
+
parsed = json.loads(payload) if payload else None
|
| 681 |
+
except json.JSONDecodeError:
|
| 682 |
+
parsed = None
|
| 683 |
+
expected = [list(batch) for batch in self._scenario.task_2.hidden_expected]
|
| 684 |
+
return result.returncode == 0 and parsed == expected
|
| 685 |
+
|
| 686 |
+
def _task_3_args_reference_report_data(self, args: list[str]) -> bool:
|
| 687 |
+
if len(args) != 1:
|
| 688 |
+
return False
|
| 689 |
+
|
| 690 |
+
expected_path = self._resolve_workspace_path("report_data.json")
|
| 691 |
+
if expected_path is None:
|
| 692 |
+
return False
|
| 693 |
+
|
| 694 |
+
candidate = args[0]
|
| 695 |
+
if os.path.isabs(candidate):
|
| 696 |
+
resolved = os.path.realpath(candidate)
|
| 697 |
+
else:
|
| 698 |
+
resolved = os.path.realpath(os.path.join(self._workspace_dir, candidate))
|
| 699 |
+
return resolved == expected_path
|
| 700 |
+
|
| 701 |
+
def _current_task_score(self) -> float:
|
| 702 |
+
if not self._state.task_id:
|
| 703 |
+
return 0.0
|
| 704 |
+
try:
|
| 705 |
+
from .grading import evaluate_task
|
| 706 |
+
|
| 707 |
+
return float(evaluate_task(self._state.task_id, self).get("score", 0.0))
|
| 708 |
+
except Exception:
|
| 709 |
+
logger.exception(
|
| 710 |
+
"Failed to compute current grader score for reward shaping."
|
| 711 |
+
)
|
| 712 |
+
return self._grader_score
|
| 713 |
+
|
| 714 |
+
def _script_compiles(self, content: str, filename: str) -> bool:
|
| 715 |
+
try:
|
| 716 |
+
compile(content, filename, "exec")
|
| 717 |
+
except SyntaxError:
|
| 718 |
+
return False
|
| 719 |
+
return True
|
| 720 |
+
|
| 721 |
+
def _task_completed(self) -> bool:
|
| 722 |
+
if self._state.task_id == "task_1_easy_anomaly" and self._scenario.task_1:
|
| 723 |
+
return self._current_transactions_rows() == list(
|
| 724 |
+
self._scenario.task_1.expected_rows
|
| 725 |
+
)
|
| 726 |
+
if self._state.task_id == "task_2_medium_syntax":
|
| 727 |
+
# Terminal grader can be <1.0 even when verified_fix (visible/hidden/provenance split).
|
| 728 |
+
return self._grader_score >= 1.0
|
| 729 |
+
if self._state.task_id == "task_3_hard_e2e":
|
| 730 |
+
# Evidence flags can be partially true while component-weighted grader is still <1.0.
|
| 731 |
+
return self._grader_score >= 1.0
|
| 732 |
+
return False
|
| 733 |
+
|
| 734 |
+
def _current_transactions_rows(self) -> list[dict[str, Any]]:
|
| 735 |
+
with sqlite3.connect(self._db_path) as conn:
|
| 736 |
+
conn.row_factory = sqlite3.Row
|
| 737 |
+
rows = conn.execute(
|
| 738 |
+
"SELECT id, user_id, amount, status FROM transactions ORDER BY id"
|
| 739 |
+
).fetchall()
|
| 740 |
+
return [
|
| 741 |
+
{
|
| 742 |
+
"id": int(row["id"]),
|
| 743 |
+
"user_id": int(row["user_id"]),
|
| 744 |
+
"amount": None
|
| 745 |
+
if row["amount"] is None
|
| 746 |
+
else round(float(row["amount"]), 2),
|
| 747 |
+
"status": str(row["status"]),
|
| 748 |
+
}
|
| 749 |
+
for row in rows
|
| 750 |
+
]
|
| 751 |
+
|
| 752 |
+
def _record_event(self, event: str) -> None:
|
| 753 |
+
self._pending_events.append(event)
|
| 754 |
+
|
| 755 |
+
def _resolve_timeout(self, timeout_s: Optional[float]) -> float:
|
| 756 |
+
if timeout_s is None:
|
| 757 |
+
return DEFAULT_ACTION_TIMEOUT_S
|
| 758 |
+
return max(0.1, min(float(timeout_s), MAX_ACTION_TIMEOUT_S))
|
| 759 |
+
|
| 760 |
+
def _is_allowed_file(
|
| 761 |
+
self, allowed_registry: dict[str, frozenset[str]], basename: str
|
| 762 |
+
) -> bool:
|
| 763 |
+
return basename in allowed_registry.get(self._state.task_id, frozenset())
|
| 764 |
+
|
| 765 |
+
def _resolve_workspace_path(self, basename: str) -> Optional[str]:
|
| 766 |
+
workspace_root = os.path.realpath(self._workspace_dir)
|
| 767 |
+
candidate = os.path.realpath(os.path.join(self._workspace_dir, basename))
|
| 768 |
+
if candidate == workspace_root:
|
| 769 |
+
return None
|
| 770 |
+
if not candidate.startswith(f"{workspace_root}{os.sep}"):
|
| 771 |
+
return None
|
| 772 |
+
return candidate
|
| 773 |
+
|
| 774 |
+
def _statement_type(self, query: str) -> str:
|
| 775 |
+
parts = query.split(None, 1)
|
| 776 |
+
return parts[0].upper() if parts else ""
|
| 777 |
+
|
| 778 |
+
def _validate_sql_action(self, query: str, statement_type: str) -> Optional[str]:
|
| 779 |
+
if not query:
|
| 780 |
+
return "SQL query cannot be empty."
|
| 781 |
+
|
| 782 |
+
policy = TASK_SQL_POLICIES.get(self._state.task_id)
|
| 783 |
+
if policy is None:
|
| 784 |
+
return "SQL is not available for the active task."
|
| 785 |
+
if statement_type not in policy.allowed_commands:
|
| 786 |
+
allowed = ", ".join(sorted(policy.allowed_commands))
|
| 787 |
+
return f"Only {allowed} statements are allowed for this task."
|
| 788 |
+
|
| 789 |
+
sanitized = self._strip_sql_literals_and_comments(query)
|
| 790 |
+
normalized = " ".join(sanitized.split())
|
| 791 |
+
lowered = normalized.lower()
|
| 792 |
+
|
| 793 |
+
if ";" in normalized:
|
| 794 |
+
return "Only a single SQL statement is allowed."
|
| 795 |
+
if any(
|
| 796 |
+
token in lowered
|
| 797 |
+
for token in ("pragma", "attach", "detach", "sqlite_", "alter ", "drop ")
|
| 798 |
+
):
|
| 799 |
+
return "Query contains disallowed SQL constructs."
|
| 800 |
+
if statement_type == "DELETE" and not re.match(
|
| 801 |
+
rf"^delete\s+from\s+{re.escape(policy.required_table)}\s+where\b",
|
| 802 |
+
lowered,
|
| 803 |
+
):
|
| 804 |
+
return f"DELETE statements must target '{policy.required_table}' with an explicit WHERE clause."
|
| 805 |
+
|
| 806 |
+
cte_names = self._extract_cte_names(normalized)
|
| 807 |
+
table_refs = self._extract_sql_table_refs(normalized)
|
| 808 |
+
if policy.required_table not in table_refs:
|
| 809 |
+
return f"Query must target the '{policy.required_table}' table."
|
| 810 |
+
|
| 811 |
+
allowed_refs = {policy.required_table, *cte_names}
|
| 812 |
+
disallowed = sorted(ref for ref in table_refs if ref not in allowed_refs)
|
| 813 |
+
if disallowed:
|
| 814 |
+
return f"Query references disallowed table(s): {', '.join(disallowed)}."
|
| 815 |
+
return None
|
| 816 |
+
|
| 817 |
+
def _strip_sql_literals_and_comments(self, query: str) -> str:
|
| 818 |
+
without_comments = _SQL_COMMENT_RE.sub(" ", query)
|
| 819 |
+
return _SQL_STRING_RE.sub("''", without_comments)
|
| 820 |
+
|
| 821 |
+
def _extract_cte_names(self, query: str) -> set[str]:
|
| 822 |
+
lowered = query.lower().lstrip()
|
| 823 |
+
if not lowered.startswith("with "):
|
| 824 |
+
return set()
|
| 825 |
+
return {match.group(1).lower() for match in _SQL_CTE_NAME_RE.finditer(query)}
|
| 826 |
+
|
| 827 |
+
def _extract_sql_table_refs(self, query: str) -> set[str]:
|
| 828 |
+
return {match.group(1).lower() for match in _SQL_TABLE_REF_RE.finditer(query)}
|
| 829 |
+
|
| 830 |
+
def _obs(
|
| 831 |
+
self, status: str, message: str, *, done: bool = False
|
| 832 |
+
) -> DataOpsObservation:
|
| 833 |
+
return DataOpsObservation(
|
| 834 |
+
status=status,
|
| 835 |
+
message=message,
|
| 836 |
+
step_count=self._state.step_count,
|
| 837 |
+
max_steps=MAX_STEPS,
|
| 838 |
+
done=done,
|
| 839 |
+
)
|
server/grading.py
ADDED
|
@@ -0,0 +1,557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Terminal graders for the seeded DataOpsEnv benchmark."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import ast
|
| 6 |
+
import json
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import sqlite3
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
from server.dataops_env_environment import DataOpsEnvironment
|
| 13 |
+
from server.safe_exec import run_python_code, run_python_script
|
| 14 |
+
from server.task_specs import (
|
| 15 |
+
build_task_3_report,
|
| 16 |
+
normalize_task_2_output_rows,
|
| 17 |
+
normalize_task_3_rows,
|
| 18 |
+
report_matches_expected,
|
| 19 |
+
task_3_data_matches_expected,
|
| 20 |
+
task_3_semantic_match_fraction_rows,
|
| 21 |
+
task_3_semantic_match_fraction_text,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
SCRIPT_TIMEOUT_S = 10
|
| 26 |
+
INTERNAL_STDOUT_LIMIT = 50_000
|
| 27 |
+
INTERNAL_STDERR_LIMIT = 10_000
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def evaluate_task(task_id: str, env: DataOpsEnvironment) -> dict[str, Any]:
|
| 31 |
+
graders = {
|
| 32 |
+
"task_1_easy_anomaly": _grade_task_1,
|
| 33 |
+
"task_2_medium_syntax": _grade_task_2,
|
| 34 |
+
"task_3_hard_e2e": _grade_task_3,
|
| 35 |
+
}
|
| 36 |
+
grader = graders.get(task_id)
|
| 37 |
+
if grader is None:
|
| 38 |
+
return {"task_id": task_id, "score": 0.0, "details": {"error": "Unknown task"}}
|
| 39 |
+
|
| 40 |
+
score, details = grader(env)
|
| 41 |
+
return {"task_id": task_id, "score": round(score, 2), "details": details}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _grade_task_1(env: DataOpsEnvironment) -> tuple[float, dict[str, Any]]:
|
| 45 |
+
if env.scenario.task_1 is None:
|
| 46 |
+
return 0.0, {"error": "Task 1 scenario missing."}
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
actual_rows = _current_transactions_rows(env.db_path)
|
| 50 |
+
except Exception:
|
| 51 |
+
logger.exception("Task 1 grading error")
|
| 52 |
+
return 0.0, {"error": "Internal grading error."}
|
| 53 |
+
|
| 54 |
+
expected_rows = list(env.scenario.task_1.expected_rows)
|
| 55 |
+
corrupted_ids = set(env.scenario.task_1.corrupted_row_ids)
|
| 56 |
+
actual_ids = {row["id"] for row in actual_rows}
|
| 57 |
+
expected_ids = {row["id"] for row in expected_rows}
|
| 58 |
+
corrupted_remaining = sorted(actual_ids & corrupted_ids)
|
| 59 |
+
rewritten_corrupted = [
|
| 60 |
+
row for row in actual_rows if row["id"] in corrupted_ids and row["amount"] is not None
|
| 61 |
+
]
|
| 62 |
+
valid_rows_intact = all(
|
| 63 |
+
any(actual == expected for actual in actual_rows) for expected in expected_rows
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
details: dict[str, Any] = {
|
| 67 |
+
"expected_row_ids": sorted(expected_ids),
|
| 68 |
+
"actual_row_ids": sorted(actual_ids),
|
| 69 |
+
"corrupted_row_ids": sorted(corrupted_ids),
|
| 70 |
+
"corrupted_remaining": corrupted_remaining,
|
| 71 |
+
"valid_rows_intact": valid_rows_intact,
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
if actual_rows == expected_rows:
|
| 75 |
+
details["reason"] = "Perfect - corrupted rows were deleted and all valid rows were preserved."
|
| 76 |
+
details["components"] = {
|
| 77 |
+
"exact_cleanup": {"score": 1.0, "max": 1.0, "passed": True},
|
| 78 |
+
}
|
| 79 |
+
return 1.0, details
|
| 80 |
+
|
| 81 |
+
if rewritten_corrupted:
|
| 82 |
+
details["reason"] = "Corrupted rows were rewritten instead of being deleted."
|
| 83 |
+
details["components"] = {
|
| 84 |
+
"exact_cleanup": {"score": 0.0, "max": 1.0, "passed": False},
|
| 85 |
+
}
|
| 86 |
+
return 0.0, details
|
| 87 |
+
|
| 88 |
+
if valid_rows_intact and corrupted_remaining:
|
| 89 |
+
fraction_removed = 1.0 - (len(corrupted_remaining) / max(len(corrupted_ids), 1))
|
| 90 |
+
score = round(0.25 * max(fraction_removed, 0.0), 4)
|
| 91 |
+
details["reason"] = "Some corrupted rows were removed, but cleanup is incomplete."
|
| 92 |
+
details["components"] = {
|
| 93 |
+
"partial_cleanup": {"score": score, "max": 0.25, "passed": False},
|
| 94 |
+
}
|
| 95 |
+
return score, details
|
| 96 |
+
|
| 97 |
+
details["reason"] = "The transaction table does not match the required cleaned state."
|
| 98 |
+
details["components"] = {
|
| 99 |
+
"exact_cleanup": {"score": 0.0, "max": 1.0, "passed": False},
|
| 100 |
+
}
|
| 101 |
+
return 0.0, details
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _grade_task_2(env: DataOpsEnvironment) -> tuple[float, dict[str, Any]]:
|
| 105 |
+
if env.scenario.task_2 is None:
|
| 106 |
+
return 0.0, {"error": "Task 2 scenario missing."}
|
| 107 |
+
|
| 108 |
+
script = os.path.join(env.workspace_dir, "broken_pipeline.py")
|
| 109 |
+
if not os.path.isfile(script):
|
| 110 |
+
return 0.0, {
|
| 111 |
+
"reason": "broken_pipeline.py not found.",
|
| 112 |
+
"components": {
|
| 113 |
+
"script_present": {"score": 0.0, "max": 1.0, "passed": False},
|
| 114 |
+
},
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
try:
|
| 118 |
+
with open(script, encoding="utf-8") as f:
|
| 119 |
+
source = f.read()
|
| 120 |
+
static = _inspect_task_2_source(source)
|
| 121 |
+
main_result = run_python_script(
|
| 122 |
+
"broken_pipeline.py",
|
| 123 |
+
cwd=env.workspace_dir,
|
| 124 |
+
args=[],
|
| 125 |
+
timeout_s=SCRIPT_TIMEOUT_S,
|
| 126 |
+
stdout_limit=INTERNAL_STDOUT_LIMIT,
|
| 127 |
+
stderr_limit=INTERNAL_STDERR_LIMIT,
|
| 128 |
+
)
|
| 129 |
+
visible_result = _run_task_2_case_check(
|
| 130 |
+
env.workspace_dir,
|
| 131 |
+
env.scenario.task_2.visible_batch,
|
| 132 |
+
env.scenario.task_2.visible_expected,
|
| 133 |
+
)
|
| 134 |
+
hidden_result = _run_task_2_hidden_tests(
|
| 135 |
+
env.workspace_dir,
|
| 136 |
+
env.scenario.task_2.hidden_cases,
|
| 137 |
+
env.scenario.task_2.hidden_expected,
|
| 138 |
+
)
|
| 139 |
+
except Exception:
|
| 140 |
+
logger.exception("Task 2 grading error")
|
| 141 |
+
return 0.0, {"error": "Internal grading error."}
|
| 142 |
+
|
| 143 |
+
if main_result.timed_out or visible_result["timed_out"] or hidden_result["timed_out"]:
|
| 144 |
+
return 0.0, {"reason": "Script timed out.", "components": {}}
|
| 145 |
+
|
| 146 |
+
hidden_score = round(0.60 * hidden_result["pass_fraction"], 4)
|
| 147 |
+
visible_score = 0.25 if visible_result["passed"] and main_result.returncode == 0 else 0.0
|
| 148 |
+
execution_score = 0.15 if env.evidence.get("task_2", {}).get("verified_fix") else 0.0
|
| 149 |
+
components: dict[str, Any] = {
|
| 150 |
+
"hidden_functional": {
|
| 151 |
+
"score": hidden_score,
|
| 152 |
+
"max": 0.60,
|
| 153 |
+
"passed": hidden_result["passed"],
|
| 154 |
+
},
|
| 155 |
+
"visible_pipeline": {
|
| 156 |
+
"score": visible_score,
|
| 157 |
+
"max": 0.25,
|
| 158 |
+
"passed": visible_result["passed"] and main_result.returncode == 0,
|
| 159 |
+
},
|
| 160 |
+
"execution_provenance": {
|
| 161 |
+
"score": execution_score,
|
| 162 |
+
"max": 0.15,
|
| 163 |
+
"passed": bool(env.evidence.get("task_2", {}).get("verified_fix")),
|
| 164 |
+
},
|
| 165 |
+
}
|
| 166 |
+
score = round(sum(component["score"] for component in components.values()), 4)
|
| 167 |
+
details = {
|
| 168 |
+
"main_exit_code": main_result.returncode,
|
| 169 |
+
"main_stdout": main_result.stdout[:500],
|
| 170 |
+
"main_stderr": main_result.stderr[:500],
|
| 171 |
+
"visible_batch_ok": visible_result["passed"],
|
| 172 |
+
"hidden_tests_passed": hidden_result["passed"],
|
| 173 |
+
"hidden_pass_fraction": hidden_result["pass_fraction"],
|
| 174 |
+
"hidden_case_passes": hidden_result["case_passes"],
|
| 175 |
+
"static_checks": static,
|
| 176 |
+
"components": components,
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
if score == 1.0:
|
| 180 |
+
details["reason"] = "Seeded hidden tests and the visible verification run both pass."
|
| 181 |
+
elif hidden_result["passed"] and main_result.returncode == 0:
|
| 182 |
+
details["reason"] = "The ETL transform is correct, but the agent never verified it through the run action."
|
| 183 |
+
elif hidden_result["pass_fraction"] > 0 and main_result.returncode == 0:
|
| 184 |
+
details["reason"] = "The repair improves the ETL transform, but it still fails some seeded cases."
|
| 185 |
+
elif hidden_result["pass_fraction"] > 0:
|
| 186 |
+
details["reason"] = "The core transform improved, but the runnable script entrypoint still drifts."
|
| 187 |
+
elif main_result.returncode == 0:
|
| 188 |
+
details["reason"] = "The script runs, but it does not yet produce the required normalized records."
|
| 189 |
+
else:
|
| 190 |
+
details["reason"] = "The repair is still incorrect or incomplete."
|
| 191 |
+
return score, details
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def _grade_task_3(env: DataOpsEnvironment) -> tuple[float, dict[str, Any]]:
|
| 195 |
+
if env.scenario.task_3 is None:
|
| 196 |
+
return 0.0, {"error": "Task 3 scenario missing."}
|
| 197 |
+
|
| 198 |
+
scenario = env.scenario.task_3
|
| 199 |
+
evidence = env.evidence.get("task_3", {})
|
| 200 |
+
expected_rows = list(scenario.expected_rows)
|
| 201 |
+
expected_report = build_task_3_report(expected_rows, scenario.target_date)
|
| 202 |
+
|
| 203 |
+
report_data = _load_task_3_data(env.workspace_dir, expected_rows)
|
| 204 |
+
formatter = _run_task_3_formatter(env.workspace_dir, expected_rows, scenario.target_date)
|
| 205 |
+
email = _score_task_3_email(env, expected_report)
|
| 206 |
+
|
| 207 |
+
report_exact_and_proven = bool(
|
| 208 |
+
report_data["matches_expected"] and evidence.get("report_data_matches_sql")
|
| 209 |
+
)
|
| 210 |
+
formatter_exact_and_proven = bool(
|
| 211 |
+
formatter["matches_expected"] and evidence.get("format_output_matches_expected")
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
components: dict[str, Any] = {
|
| 215 |
+
"sql_provenance": {
|
| 216 |
+
"score": 0.20 if evidence.get("matching_sql_executed") else 0.0,
|
| 217 |
+
"max": 0.20,
|
| 218 |
+
"passed": bool(evidence.get("matching_sql_executed")),
|
| 219 |
+
},
|
| 220 |
+
"report_data": {
|
| 221 |
+
"score": 0.20 if report_exact_and_proven else 0.05 if report_data["matches_expected"] else 0.0,
|
| 222 |
+
"max": 0.20,
|
| 223 |
+
"passed": report_exact_and_proven,
|
| 224 |
+
},
|
| 225 |
+
"formatter": {
|
| 226 |
+
"score": 0.25 if formatter_exact_and_proven else 0.05 if formatter["runs"] else 0.0,
|
| 227 |
+
"max": 0.25,
|
| 228 |
+
"passed": formatter_exact_and_proven,
|
| 229 |
+
},
|
| 230 |
+
"email": {
|
| 231 |
+
"score": email["score"],
|
| 232 |
+
"max": 0.35,
|
| 233 |
+
"passed": email["passed"],
|
| 234 |
+
},
|
| 235 |
+
}
|
| 236 |
+
score = round(sum(component["score"] for component in components.values()), 4)
|
| 237 |
+
details: dict[str, Any] = {
|
| 238 |
+
"target_date": scenario.target_date,
|
| 239 |
+
"expected_recipient": scenario.recipient,
|
| 240 |
+
"expected_subject": scenario.subject,
|
| 241 |
+
"report_data": report_data["details"],
|
| 242 |
+
"formatter": formatter["details"],
|
| 243 |
+
"email": email["details"],
|
| 244 |
+
"evidence": evidence,
|
| 245 |
+
"components": components,
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
if score == 1.0:
|
| 249 |
+
details["reason"] = "Perfect - the seeded SQL slice, JSON output, formatter run, and final email all align."
|
| 250 |
+
elif score >= 0.55:
|
| 251 |
+
details["reason"] = "Strong progress - some of the seeded workflow is correct, but provenance is incomplete."
|
| 252 |
+
elif score > 0:
|
| 253 |
+
details["reason"] = "Partial progress - artifacts exist, but the end-to-end incident workflow is not proven."
|
| 254 |
+
else:
|
| 255 |
+
details["reason"] = "The seeded hard task is still unsolved."
|
| 256 |
+
return score, details
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def _inspect_task_2_source(source: str) -> dict[str, Any]:
|
| 260 |
+
try:
|
| 261 |
+
tree = ast.parse(source)
|
| 262 |
+
except SyntaxError as exc:
|
| 263 |
+
return {"passed": False, "error": str(exc), "has_function": False}
|
| 264 |
+
|
| 265 |
+
functions = [node for node in tree.body if isinstance(node, ast.FunctionDef)]
|
| 266 |
+
target = next((node for node in functions if node.name == "process_data_stream"), None)
|
| 267 |
+
passed = target is not None and len(target.args.args) == 1
|
| 268 |
+
return {"passed": passed, "has_function": target is not None}
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def _run_task_2_case_check(
|
| 272 |
+
workspace_dir: str,
|
| 273 |
+
batch: tuple[dict[str, Any], ...],
|
| 274 |
+
expected: tuple[dict[str, Any], ...],
|
| 275 |
+
) -> dict[str, Any]:
|
| 276 |
+
wrapper = f"""
|
| 277 |
+
import importlib.util
|
| 278 |
+
import json
|
| 279 |
+
|
| 280 |
+
spec = importlib.util.spec_from_file_location("candidate_pipeline", "broken_pipeline.py")
|
| 281 |
+
module = importlib.util.module_from_spec(spec)
|
| 282 |
+
assert spec.loader is not None
|
| 283 |
+
spec.loader.exec_module(module)
|
| 284 |
+
|
| 285 |
+
batch = {json.dumps(list(batch))}
|
| 286 |
+
results = module.process_data_stream(batch)
|
| 287 |
+
print("__RESULT__=" + json.dumps(results))
|
| 288 |
+
"""
|
| 289 |
+
result = run_python_code(
|
| 290 |
+
wrapper,
|
| 291 |
+
cwd=workspace_dir,
|
| 292 |
+
timeout_s=SCRIPT_TIMEOUT_S,
|
| 293 |
+
stdout_limit=INTERNAL_STDOUT_LIMIT,
|
| 294 |
+
stderr_limit=INTERNAL_STDERR_LIMIT,
|
| 295 |
+
)
|
| 296 |
+
payload = next(
|
| 297 |
+
(
|
| 298 |
+
line[len("__RESULT__=") :]
|
| 299 |
+
for line in result.stdout.splitlines()
|
| 300 |
+
if line.startswith("__RESULT__=")
|
| 301 |
+
),
|
| 302 |
+
"",
|
| 303 |
+
)
|
| 304 |
+
try:
|
| 305 |
+
parsed = json.loads(payload) if payload else None
|
| 306 |
+
except json.JSONDecodeError:
|
| 307 |
+
parsed = None
|
| 308 |
+
normalised = normalize_task_2_output_rows(parsed)
|
| 309 |
+
ok = result.returncode == 0 and normalised == list(expected)
|
| 310 |
+
return {
|
| 311 |
+
"passed": ok,
|
| 312 |
+
"timed_out": result.timed_out,
|
| 313 |
+
"stdout": result.stdout[:500],
|
| 314 |
+
"stderr": result.stderr[:500],
|
| 315 |
+
"actual": normalised,
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def _run_task_2_hidden_tests(
|
| 320 |
+
workspace_dir: str,
|
| 321 |
+
hidden_cases: tuple[tuple[dict[str, Any], ...], ...],
|
| 322 |
+
hidden_expected: tuple[tuple[dict[str, Any], ...], ...],
|
| 323 |
+
) -> dict[str, Any]:
|
| 324 |
+
wrapper = f"""
|
| 325 |
+
import importlib.util
|
| 326 |
+
import json
|
| 327 |
+
|
| 328 |
+
spec = importlib.util.spec_from_file_location("candidate_pipeline", "broken_pipeline.py")
|
| 329 |
+
module = importlib.util.module_from_spec(spec)
|
| 330 |
+
assert spec.loader is not None
|
| 331 |
+
spec.loader.exec_module(module)
|
| 332 |
+
|
| 333 |
+
cases = {json.dumps([list(batch) for batch in hidden_cases])}
|
| 334 |
+
results = [module.process_data_stream(case) for case in cases]
|
| 335 |
+
print("__RESULT__=" + json.dumps(results))
|
| 336 |
+
"""
|
| 337 |
+
result = run_python_code(
|
| 338 |
+
wrapper,
|
| 339 |
+
cwd=workspace_dir,
|
| 340 |
+
timeout_s=SCRIPT_TIMEOUT_S,
|
| 341 |
+
stdout_limit=INTERNAL_STDOUT_LIMIT,
|
| 342 |
+
stderr_limit=INTERNAL_STDERR_LIMIT,
|
| 343 |
+
)
|
| 344 |
+
payload = next(
|
| 345 |
+
(
|
| 346 |
+
line[len("__RESULT__=") :]
|
| 347 |
+
for line in result.stdout.splitlines()
|
| 348 |
+
if line.startswith("__RESULT__=")
|
| 349 |
+
),
|
| 350 |
+
"",
|
| 351 |
+
)
|
| 352 |
+
try:
|
| 353 |
+
parsed = json.loads(payload) if payload else None
|
| 354 |
+
except json.JSONDecodeError:
|
| 355 |
+
parsed = None
|
| 356 |
+
if not isinstance(parsed, list):
|
| 357 |
+
parsed = []
|
| 358 |
+
|
| 359 |
+
actual_batches = [
|
| 360 |
+
normalize_task_2_output_rows(batch)
|
| 361 |
+
for batch in parsed
|
| 362 |
+
]
|
| 363 |
+
expected = [list(batch) for batch in hidden_expected]
|
| 364 |
+
case_passes = [
|
| 365 |
+
actual == expected_case
|
| 366 |
+
for actual, expected_case in zip(actual_batches, expected, strict=False)
|
| 367 |
+
]
|
| 368 |
+
if len(case_passes) < len(expected):
|
| 369 |
+
case_passes.extend([False] * (len(expected) - len(case_passes)))
|
| 370 |
+
pass_fraction = (
|
| 371 |
+
sum(1 for passed in case_passes if passed) / len(expected)
|
| 372 |
+
if expected
|
| 373 |
+
else 0.0
|
| 374 |
+
)
|
| 375 |
+
return {
|
| 376 |
+
"passed": result.returncode == 0 and len(actual_batches) == len(expected) and all(case_passes),
|
| 377 |
+
"timed_out": result.timed_out,
|
| 378 |
+
"stdout": result.stdout[:500],
|
| 379 |
+
"stderr": result.stderr[:500],
|
| 380 |
+
"actual": actual_batches,
|
| 381 |
+
"case_passes": case_passes,
|
| 382 |
+
"pass_fraction": round(pass_fraction, 4),
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def _load_task_3_data(
|
| 387 |
+
workspace_dir: str, expected_rows: list[dict[str, Any]]
|
| 388 |
+
) -> dict[str, Any]:
|
| 389 |
+
report_json = os.path.join(workspace_dir, "report_data.json")
|
| 390 |
+
if not os.path.isfile(report_json):
|
| 391 |
+
return {
|
| 392 |
+
"matches_expected": False,
|
| 393 |
+
"details": {"exists": False, "reason": "report_data.json not found."},
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
try:
|
| 397 |
+
with open(report_json, encoding="utf-8") as f:
|
| 398 |
+
payload = json.load(f)
|
| 399 |
+
except (OSError, json.JSONDecodeError) as exc:
|
| 400 |
+
return {
|
| 401 |
+
"matches_expected": False,
|
| 402 |
+
"details": {"exists": True, "reason": str(exc)},
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
if not isinstance(payload, list):
|
| 406 |
+
return {
|
| 407 |
+
"matches_expected": False,
|
| 408 |
+
"details": {
|
| 409 |
+
"exists": True,
|
| 410 |
+
"reason": "report_data.json must contain a JSON list.",
|
| 411 |
+
},
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
rows = normalize_task_3_rows(payload, require_headcount=True)
|
| 415 |
+
matches_expected = bool(rows) and task_3_data_matches_expected(
|
| 416 |
+
rows,
|
| 417 |
+
expected_rows,
|
| 418 |
+
require_headcount=True,
|
| 419 |
+
)
|
| 420 |
+
semantic_fraction = task_3_semantic_match_fraction_rows(rows, expected_rows)
|
| 421 |
+
return {
|
| 422 |
+
"matches_expected": matches_expected,
|
| 423 |
+
"details": {
|
| 424 |
+
"exists": True,
|
| 425 |
+
"rows_valid": bool(rows),
|
| 426 |
+
"rows_match_expected": matches_expected,
|
| 427 |
+
"semantic_fraction": round(semantic_fraction, 4),
|
| 428 |
+
},
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def _run_task_3_formatter(
|
| 433 |
+
workspace_dir: str,
|
| 434 |
+
expected_rows: list[dict[str, Any]],
|
| 435 |
+
target_date: str,
|
| 436 |
+
) -> dict[str, Any]:
|
| 437 |
+
script = os.path.join(workspace_dir, "format_report.py")
|
| 438 |
+
if not os.path.isfile(script):
|
| 439 |
+
return {
|
| 440 |
+
"runs": False,
|
| 441 |
+
"matches_expected": False,
|
| 442 |
+
"details": {"reason": "format_report.py not found."},
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
try:
|
| 446 |
+
result = run_python_script(
|
| 447 |
+
"format_report.py",
|
| 448 |
+
cwd=workspace_dir,
|
| 449 |
+
args=["report_data.json"],
|
| 450 |
+
timeout_s=SCRIPT_TIMEOUT_S,
|
| 451 |
+
stdout_limit=INTERNAL_STDOUT_LIMIT,
|
| 452 |
+
stderr_limit=INTERNAL_STDERR_LIMIT,
|
| 453 |
+
)
|
| 454 |
+
except Exception as exc:
|
| 455 |
+
return {
|
| 456 |
+
"runs": False,
|
| 457 |
+
"matches_expected": False,
|
| 458 |
+
"details": {"reason": str(exc)},
|
| 459 |
+
}
|
| 460 |
+
if result.timed_out:
|
| 461 |
+
return {
|
| 462 |
+
"runs": False,
|
| 463 |
+
"matches_expected": False,
|
| 464 |
+
"details": {"reason": "Formatter timed out."},
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
stdout = (result.stdout or "").strip()
|
| 468 |
+
matches_expected = result.returncode == 0 and report_matches_expected(
|
| 469 |
+
stdout,
|
| 470 |
+
expected_rows,
|
| 471 |
+
target_date,
|
| 472 |
+
)
|
| 473 |
+
return {
|
| 474 |
+
"runs": result.returncode == 0,
|
| 475 |
+
"matches_expected": matches_expected,
|
| 476 |
+
"details": {
|
| 477 |
+
"exit_code": result.returncode,
|
| 478 |
+
"stdout": stdout[:500],
|
| 479 |
+
"stderr": (result.stderr or "")[:500],
|
| 480 |
+
"semantic_fraction": round(
|
| 481 |
+
task_3_semantic_match_fraction_text(stdout, expected_rows, target_date),
|
| 482 |
+
4,
|
| 483 |
+
),
|
| 484 |
+
},
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def _score_task_3_email(
|
| 489 |
+
env: DataOpsEnvironment, expected_report: str
|
| 490 |
+
) -> dict[str, Any]:
|
| 491 |
+
scenario = env.scenario.task_3
|
| 492 |
+
assert scenario is not None
|
| 493 |
+
evidence = env.evidence.get("task_3", {})
|
| 494 |
+
outbox = env.email_outbox
|
| 495 |
+
if not outbox:
|
| 496 |
+
return {
|
| 497 |
+
"score": 0.0,
|
| 498 |
+
"passed": False,
|
| 499 |
+
"details": {"reason": "No email sent."},
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
+
email = outbox[-1]
|
| 503 |
+
recipient_ok = email.get("to_email") == scenario.recipient
|
| 504 |
+
subject_ok = email.get("subject") == scenario.subject
|
| 505 |
+
body = str(email.get("body", "")).strip()
|
| 506 |
+
body_ok = body == expected_report.strip()
|
| 507 |
+
proven = bool(evidence.get("email_matches_formatter_output")) and len(outbox) == 1
|
| 508 |
+
|
| 509 |
+
score = 0.0
|
| 510 |
+
if recipient_ok:
|
| 511 |
+
score += 0.05
|
| 512 |
+
if subject_ok:
|
| 513 |
+
score += 0.05
|
| 514 |
+
if body_ok and proven:
|
| 515 |
+
score += 0.25
|
| 516 |
+
|
| 517 |
+
return {
|
| 518 |
+
"score": score,
|
| 519 |
+
"passed": score == 0.35,
|
| 520 |
+
"details": {
|
| 521 |
+
"emails_sent": len(outbox),
|
| 522 |
+
"recipient_ok": recipient_ok,
|
| 523 |
+
"subject_ok": subject_ok,
|
| 524 |
+
"body_ok": body_ok,
|
| 525 |
+
"proven": proven,
|
| 526 |
+
"semantic_fraction": round(
|
| 527 |
+
task_3_semantic_match_fraction_text(
|
| 528 |
+
body,
|
| 529 |
+
list(scenario.expected_rows),
|
| 530 |
+
scenario.target_date,
|
| 531 |
+
),
|
| 532 |
+
4,
|
| 533 |
+
),
|
| 534 |
+
},
|
| 535 |
+
}
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def _current_transactions_rows(db_path: str) -> list[dict[str, Any]]:
|
| 539 |
+
with sqlite3.connect(db_path) as conn:
|
| 540 |
+
conn.row_factory = sqlite3.Row
|
| 541 |
+
table_exists = conn.execute(
|
| 542 |
+
"SELECT name FROM sqlite_master WHERE type='table' AND name='transactions'"
|
| 543 |
+
).fetchone()
|
| 544 |
+
if not table_exists:
|
| 545 |
+
return []
|
| 546 |
+
rows = conn.execute(
|
| 547 |
+
"SELECT id, user_id, amount, status FROM transactions ORDER BY id"
|
| 548 |
+
).fetchall()
|
| 549 |
+
return [
|
| 550 |
+
{
|
| 551 |
+
"id": int(row["id"]),
|
| 552 |
+
"user_id": int(row["user_id"]),
|
| 553 |
+
"amount": None if row["amount"] is None else round(float(row["amount"]), 2),
|
| 554 |
+
"status": str(row["status"]),
|
| 555 |
+
}
|
| 556 |
+
for row in rows
|
| 557 |
+
]
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.2
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
starlette>=0.46.0,<0.52.0
|
| 4 |
+
uvicorn[standard]>=0.34.0
|
| 5 |
+
pydantic>=2.10.0
|
| 6 |
+
pyyaml>=6.0.2
|
| 7 |
+
openai>=1.60.0
|
| 8 |
+
requests>=2.32.0
|
| 9 |
+
wsproto>=1.3.2
|
| 10 |
+
python-dotenv>=1.0.0
|
server/safe_exec.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import signal
|
| 7 |
+
import subprocess
|
| 8 |
+
import sys
|
| 9 |
+
import tempfile
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
|
| 12 |
+
DEFAULT_ADDRESS_SPACE_BYTES = 512 * 1024 * 1024
|
| 13 |
+
DEFAULT_FILE_BYTES = 2 * 1024 * 1024
|
| 14 |
+
DEFAULT_OPEN_FILES = 64
|
| 15 |
+
DEFAULT_PROCESSES = 32
|
| 16 |
+
|
| 17 |
+
_RUNNER_BOOTSTRAP = r"""
|
| 18 |
+
import json
|
| 19 |
+
import runpy
|
| 20 |
+
import sys
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
import resource
|
| 24 |
+
except ImportError: # pragma: no cover
|
| 25 |
+
resource = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _set_limit(name, value):
|
| 29 |
+
if resource is None or not hasattr(resource, name):
|
| 30 |
+
return
|
| 31 |
+
limit = int(value)
|
| 32 |
+
try:
|
| 33 |
+
_, current_hard = resource.getrlimit(getattr(resource, name))
|
| 34 |
+
soft = min(limit, current_hard) if current_hard >= 0 else limit
|
| 35 |
+
resource.setrlimit(getattr(resource, name), (soft, current_hard))
|
| 36 |
+
except (OSError, ValueError):
|
| 37 |
+
return
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
config = json.loads(sys.argv[1])
|
| 41 |
+
_set_limit("RLIMIT_CORE", 0)
|
| 42 |
+
_set_limit("RLIMIT_CPU", config["cpu_seconds"])
|
| 43 |
+
_set_limit("RLIMIT_FSIZE", config["file_bytes"])
|
| 44 |
+
_set_limit("RLIMIT_NOFILE", config["open_files"])
|
| 45 |
+
_set_limit("RLIMIT_AS", config["address_space_bytes"])
|
| 46 |
+
_set_limit("RLIMIT_NPROC", config["processes"])
|
| 47 |
+
|
| 48 |
+
mode = config["mode"]
|
| 49 |
+
if mode == "script":
|
| 50 |
+
script = sys.argv[2]
|
| 51 |
+
sys.argv = sys.argv[2:]
|
| 52 |
+
runpy.run_path(script, run_name="__main__")
|
| 53 |
+
elif mode == "code":
|
| 54 |
+
sys.argv = ["-c"]
|
| 55 |
+
exec(config["code"], {"__name__": "__main__"})
|
| 56 |
+
else: # pragma: no cover
|
| 57 |
+
raise SystemExit(f"Unsupported execution mode: {mode}")
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@dataclass(frozen=True)
|
| 62 |
+
class PythonRunResult:
|
| 63 |
+
returncode: int
|
| 64 |
+
stdout: str
|
| 65 |
+
stderr: str
|
| 66 |
+
timed_out: bool = False
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _safe_env(workspace_dir: str) -> dict[str, str]:
|
| 70 |
+
return {
|
| 71 |
+
"HOME": workspace_dir,
|
| 72 |
+
"TMPDIR": workspace_dir,
|
| 73 |
+
"LANG": "C.UTF-8",
|
| 74 |
+
"LC_ALL": "C.UTF-8",
|
| 75 |
+
"PATH": "",
|
| 76 |
+
"PYTHONDONTWRITEBYTECODE": "1",
|
| 77 |
+
"PYTHONHASHSEED": "0",
|
| 78 |
+
"PYTHONIOENCODING": "utf-8",
|
| 79 |
+
"PYTHONNOUSERSITE": "1",
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _limit_config(timeout_s: float) -> dict[str, int]:
|
| 84 |
+
return {
|
| 85 |
+
"cpu_seconds": max(1, int(math.ceil(timeout_s)) + 1),
|
| 86 |
+
"file_bytes": DEFAULT_FILE_BYTES,
|
| 87 |
+
"open_files": DEFAULT_OPEN_FILES,
|
| 88 |
+
"address_space_bytes": DEFAULT_ADDRESS_SPACE_BYTES,
|
| 89 |
+
"processes": DEFAULT_PROCESSES,
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _read_limited_text(handle, limit: int) -> str:
|
| 94 |
+
handle.seek(0)
|
| 95 |
+
data = handle.read(limit + 1)
|
| 96 |
+
if isinstance(data, bytes):
|
| 97 |
+
return data.decode("utf-8", errors="replace")[:limit]
|
| 98 |
+
return str(data)[:limit]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _terminate_process(proc: subprocess.Popen[bytes]) -> None:
|
| 102 |
+
if proc.poll() is not None:
|
| 103 |
+
return
|
| 104 |
+
if os.name != "nt":
|
| 105 |
+
try:
|
| 106 |
+
os.killpg(proc.pid, signal.SIGKILL)
|
| 107 |
+
return
|
| 108 |
+
except ProcessLookupError:
|
| 109 |
+
return
|
| 110 |
+
proc.kill()
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _run_python_command(
|
| 114 |
+
config: dict[str, object],
|
| 115 |
+
*,
|
| 116 |
+
cwd: str,
|
| 117 |
+
argv: list[str],
|
| 118 |
+
timeout_s: float,
|
| 119 |
+
stdout_limit: int,
|
| 120 |
+
stderr_limit: int,
|
| 121 |
+
) -> PythonRunResult:
|
| 122 |
+
command = [
|
| 123 |
+
sys.executable,
|
| 124 |
+
"-I",
|
| 125 |
+
"-B",
|
| 126 |
+
"-c",
|
| 127 |
+
_RUNNER_BOOTSTRAP,
|
| 128 |
+
json.dumps(config, ensure_ascii=True),
|
| 129 |
+
*argv,
|
| 130 |
+
]
|
| 131 |
+
start_new_session = os.name != "nt"
|
| 132 |
+
|
| 133 |
+
with tempfile.TemporaryFile() as stdout_file, tempfile.TemporaryFile() as stderr_file:
|
| 134 |
+
proc = subprocess.Popen(
|
| 135 |
+
command,
|
| 136 |
+
cwd=cwd,
|
| 137 |
+
env=_safe_env(cwd),
|
| 138 |
+
stdin=subprocess.DEVNULL,
|
| 139 |
+
stdout=stdout_file,
|
| 140 |
+
stderr=stderr_file,
|
| 141 |
+
start_new_session=start_new_session,
|
| 142 |
+
)
|
| 143 |
+
timed_out = False
|
| 144 |
+
try:
|
| 145 |
+
proc.wait(timeout=timeout_s)
|
| 146 |
+
except subprocess.TimeoutExpired:
|
| 147 |
+
timed_out = True
|
| 148 |
+
_terminate_process(proc)
|
| 149 |
+
proc.wait()
|
| 150 |
+
|
| 151 |
+
return PythonRunResult(
|
| 152 |
+
returncode=proc.returncode if proc.returncode is not None else -1,
|
| 153 |
+
stdout=_read_limited_text(stdout_file, stdout_limit),
|
| 154 |
+
stderr=_read_limited_text(stderr_file, stderr_limit),
|
| 155 |
+
timed_out=timed_out,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def run_python_script(
|
| 160 |
+
script_name: str,
|
| 161 |
+
*,
|
| 162 |
+
cwd: str,
|
| 163 |
+
args: list[str],
|
| 164 |
+
timeout_s: float,
|
| 165 |
+
stdout_limit: int,
|
| 166 |
+
stderr_limit: int,
|
| 167 |
+
) -> PythonRunResult:
|
| 168 |
+
config = {"mode": "script", **_limit_config(timeout_s)}
|
| 169 |
+
return _run_python_command(
|
| 170 |
+
config,
|
| 171 |
+
cwd=cwd,
|
| 172 |
+
argv=[script_name, *args],
|
| 173 |
+
timeout_s=timeout_s,
|
| 174 |
+
stdout_limit=stdout_limit,
|
| 175 |
+
stderr_limit=stderr_limit,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def run_python_code(
|
| 180 |
+
code: str,
|
| 181 |
+
*,
|
| 182 |
+
cwd: str,
|
| 183 |
+
timeout_s: float,
|
| 184 |
+
stdout_limit: int,
|
| 185 |
+
stderr_limit: int,
|
| 186 |
+
) -> PythonRunResult:
|
| 187 |
+
config = {"mode": "code", "code": code, **_limit_config(timeout_s)}
|
| 188 |
+
return _run_python_command(
|
| 189 |
+
config,
|
| 190 |
+
cwd=cwd,
|
| 191 |
+
argv=[],
|
| 192 |
+
timeout_s=timeout_s,
|
| 193 |
+
stdout_limit=stdout_limit,
|
| 194 |
+
stderr_limit=stderr_limit,
|
| 195 |
+
)
|
server/session_manager.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
import threading
|
| 5 |
+
import uuid
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
from server.dataops_env_environment import DataOpsEnvironment
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class SessionRecord:
|
| 14 |
+
env: DataOpsEnvironment
|
| 15 |
+
last_access_at: float
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class EnvironmentSessionManager:
|
| 19 |
+
"""Small in-memory session store for isolated environment instances."""
|
| 20 |
+
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
*,
|
| 24 |
+
max_sessions: int = 128,
|
| 25 |
+
session_timeout_s: float = 1800.0,
|
| 26 |
+
) -> None:
|
| 27 |
+
self._lock = threading.Lock()
|
| 28 |
+
self._sessions: dict[str, SessionRecord] = {}
|
| 29 |
+
self._max_sessions = max(1, max_sessions)
|
| 30 |
+
self._session_timeout_s = max(1.0, session_timeout_s)
|
| 31 |
+
|
| 32 |
+
def reset_session(
|
| 33 |
+
self,
|
| 34 |
+
*,
|
| 35 |
+
task_id: str,
|
| 36 |
+
seed: Optional[int],
|
| 37 |
+
episode_id: Optional[str],
|
| 38 |
+
session_id: Optional[str],
|
| 39 |
+
) -> tuple[str, DataOpsEnvironment, object]:
|
| 40 |
+
now = time.monotonic()
|
| 41 |
+
to_close: list[DataOpsEnvironment] = []
|
| 42 |
+
|
| 43 |
+
with self._lock:
|
| 44 |
+
to_close.extend(self._collect_expired_envs_locked(now))
|
| 45 |
+
|
| 46 |
+
record = self._sessions.get(session_id) if session_id else None
|
| 47 |
+
if record is None:
|
| 48 |
+
resolved_session_id = str(uuid.uuid4())
|
| 49 |
+
to_close.extend(self._evict_if_full_locked(now))
|
| 50 |
+
env = DataOpsEnvironment()
|
| 51 |
+
self._sessions[resolved_session_id] = SessionRecord(
|
| 52 |
+
env=env,
|
| 53 |
+
last_access_at=now,
|
| 54 |
+
)
|
| 55 |
+
else:
|
| 56 |
+
resolved_session_id = session_id or str(uuid.uuid4())
|
| 57 |
+
record.last_access_at = now
|
| 58 |
+
env = record.env
|
| 59 |
+
|
| 60 |
+
self._close_envs(to_close)
|
| 61 |
+
obs = env.reset(seed=seed, episode_id=episode_id, task_id=task_id)
|
| 62 |
+
return resolved_session_id, env, obs
|
| 63 |
+
|
| 64 |
+
def get_session(
|
| 65 |
+
self, session_id: Optional[str]
|
| 66 |
+
) -> tuple[Optional[str], Optional[DataOpsEnvironment]]:
|
| 67 |
+
now = time.monotonic()
|
| 68 |
+
to_close: list[DataOpsEnvironment] = []
|
| 69 |
+
|
| 70 |
+
with self._lock:
|
| 71 |
+
to_close.extend(self._collect_expired_envs_locked(now))
|
| 72 |
+
|
| 73 |
+
if session_id:
|
| 74 |
+
record = self._sessions.get(session_id)
|
| 75 |
+
if record is not None:
|
| 76 |
+
record.last_access_at = now
|
| 77 |
+
env = record.env
|
| 78 |
+
else:
|
| 79 |
+
env = None
|
| 80 |
+
result = (session_id, env)
|
| 81 |
+
else:
|
| 82 |
+
result = (None, None)
|
| 83 |
+
|
| 84 |
+
self._close_envs(to_close)
|
| 85 |
+
return result
|
| 86 |
+
|
| 87 |
+
def close_all(self) -> None:
|
| 88 |
+
with self._lock:
|
| 89 |
+
records = list(self._sessions.values())
|
| 90 |
+
self._sessions.clear()
|
| 91 |
+
|
| 92 |
+
self._close_envs([record.env for record in records])
|
| 93 |
+
|
| 94 |
+
def _collect_expired_envs_locked(self, now: float) -> list[DataOpsEnvironment]:
|
| 95 |
+
expired_ids = [
|
| 96 |
+
session_id
|
| 97 |
+
for session_id, record in self._sessions.items()
|
| 98 |
+
if now - record.last_access_at > self._session_timeout_s
|
| 99 |
+
]
|
| 100 |
+
return self._remove_sessions_locked(expired_ids)
|
| 101 |
+
|
| 102 |
+
def _evict_if_full_locked(self, now: float) -> list[DataOpsEnvironment]:
|
| 103 |
+
if len(self._sessions) < self._max_sessions:
|
| 104 |
+
return []
|
| 105 |
+
|
| 106 |
+
oldest_session_id = min(
|
| 107 |
+
self._sessions,
|
| 108 |
+
key=lambda session_id: self._sessions[session_id].last_access_at,
|
| 109 |
+
)
|
| 110 |
+
return self._remove_sessions_locked([oldest_session_id])
|
| 111 |
+
|
| 112 |
+
def _remove_sessions_locked(self, session_ids: list[str]) -> list[DataOpsEnvironment]:
|
| 113 |
+
removed: list[DataOpsEnvironment] = []
|
| 114 |
+
for session_id in session_ids:
|
| 115 |
+
record = self._sessions.pop(session_id, None)
|
| 116 |
+
if record is not None:
|
| 117 |
+
removed.append(record.env)
|
| 118 |
+
return removed
|
| 119 |
+
|
| 120 |
+
def _close_envs(self, envs: list[DataOpsEnvironment]) -> None:
|
| 121 |
+
for env in envs:
|
| 122 |
+
env.close()
|
| 123 |
+
|
| 124 |
+
def __del__(self) -> None:
|
| 125 |
+
try:
|
| 126 |
+
self.close_all()
|
| 127 |
+
except Exception:
|
| 128 |
+
pass
|
server/task_specs.py
ADDED
|
@@ -0,0 +1,773 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Seeded task metadata and deterministic scenario builders for DataOpsEnv."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
import re
|
| 7 |
+
import textwrap
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from datetime import date, timedelta
|
| 10 |
+
from typing import Any, Iterable
|
| 11 |
+
|
| 12 |
+
TASK_IDS = [
|
| 13 |
+
"task_1_easy_anomaly",
|
| 14 |
+
"task_2_medium_syntax",
|
| 15 |
+
"task_3_hard_e2e",
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass(frozen=True)
|
| 20 |
+
class SQLPolicy:
|
| 21 |
+
allowed_commands: frozenset[str]
|
| 22 |
+
required_table: str
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass(frozen=True)
|
| 26 |
+
class TaskMetadata:
|
| 27 |
+
task_id: str
|
| 28 |
+
name: str
|
| 29 |
+
difficulty: str
|
| 30 |
+
short_description: str
|
| 31 |
+
benchmark_focus: str
|
| 32 |
+
allowed_actions: tuple[str, ...]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass(frozen=True)
|
| 36 |
+
class Task1Scenario:
|
| 37 |
+
description: str
|
| 38 |
+
all_rows: tuple[dict[str, Any], ...]
|
| 39 |
+
expected_rows: tuple[dict[str, Any], ...]
|
| 40 |
+
corrupted_row_ids: tuple[int, ...]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass(frozen=True)
|
| 44 |
+
class Task2Scenario:
|
| 45 |
+
description: str
|
| 46 |
+
visible_batch: tuple[dict[str, Any], ...]
|
| 47 |
+
visible_expected: tuple[dict[str, Any], ...]
|
| 48 |
+
hidden_cases: tuple[tuple[dict[str, Any], ...], ...]
|
| 49 |
+
hidden_expected: tuple[tuple[dict[str, Any], ...], ...]
|
| 50 |
+
broken_script: str
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass(frozen=True)
|
| 54 |
+
class Task3Scenario:
|
| 55 |
+
description: str
|
| 56 |
+
target_date: str
|
| 57 |
+
recipient: str
|
| 58 |
+
subject: str
|
| 59 |
+
report_title: str
|
| 60 |
+
all_rows: tuple[dict[str, Any], ...]
|
| 61 |
+
expected_rows: tuple[dict[str, Any], ...]
|
| 62 |
+
broken_script: str
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@dataclass(frozen=True)
|
| 66 |
+
class TaskScenarioBundle:
|
| 67 |
+
task_id: str
|
| 68 |
+
seed: int
|
| 69 |
+
description: str
|
| 70 |
+
task_1: Task1Scenario | None = None
|
| 71 |
+
task_2: Task2Scenario | None = None
|
| 72 |
+
task_3: Task3Scenario | None = None
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
TASK_METADATA = {
|
| 76 |
+
"task_1_easy_anomaly": TaskMetadata(
|
| 77 |
+
task_id="task_1_easy_anomaly",
|
| 78 |
+
name="Delete Corrupted Transaction Rows",
|
| 79 |
+
difficulty="easy",
|
| 80 |
+
short_description=(
|
| 81 |
+
"Inspect a transaction table and remove only the seeded rows with NULL amounts while preserving legitimate non-null edge values."
|
| 82 |
+
),
|
| 83 |
+
benchmark_focus="Careful data cleanup without collateral damage.",
|
| 84 |
+
allowed_actions=("ExecuteSQL",),
|
| 85 |
+
),
|
| 86 |
+
"task_2_medium_syntax": TaskMetadata(
|
| 87 |
+
task_id="task_2_medium_syntax",
|
| 88 |
+
name="Repair Seeded Pipeline Script",
|
| 89 |
+
difficulty="medium",
|
| 90 |
+
short_description=(
|
| 91 |
+
"Repair a seeded ETL normalization script and verify it on visible and hidden seeded batches."
|
| 92 |
+
),
|
| 93 |
+
benchmark_focus="Code reading, precise repair, and generalization beyond the demo batch.",
|
| 94 |
+
allowed_actions=("ReadFile", "WriteFile", "RunScript"),
|
| 95 |
+
),
|
| 96 |
+
"task_3_hard_e2e": TaskMetadata(
|
| 97 |
+
task_id="task_3_hard_e2e",
|
| 98 |
+
name="Resolve Revenue Reporting Incident",
|
| 99 |
+
difficulty="hard",
|
| 100 |
+
short_description=(
|
| 101 |
+
"Extract a seeded reporting slice, repair the formatter, and send the exact generated report."
|
| 102 |
+
),
|
| 103 |
+
benchmark_focus="End-to-end data extraction, file repair, and communication with provenance.",
|
| 104 |
+
allowed_actions=("ExecuteSQL", "ReadFile", "WriteFile", "RunScript", "SendEmail"),
|
| 105 |
+
),
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
TASK_DESCRIPTIONS = {
|
| 109 |
+
task_id: metadata.short_description for task_id, metadata in TASK_METADATA.items()
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
TASK_ALLOWED_WRITE_FILES = {
|
| 113 |
+
"task_1_easy_anomaly": frozenset(),
|
| 114 |
+
"task_2_medium_syntax": frozenset({"broken_pipeline.py"}),
|
| 115 |
+
"task_3_hard_e2e": frozenset({"format_report.py", "report_data.json"}),
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
TASK_ALLOWED_RUN_FILES = {
|
| 119 |
+
"task_1_easy_anomaly": frozenset(),
|
| 120 |
+
"task_2_medium_syntax": frozenset({"broken_pipeline.py"}),
|
| 121 |
+
"task_3_hard_e2e": frozenset({"format_report.py"}),
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
TASK_EMAIL_ENABLED = frozenset({"task_3_hard_e2e"})
|
| 125 |
+
|
| 126 |
+
TASK_ALLOWED_READ_FILES = {
|
| 127 |
+
"task_1_easy_anomaly": frozenset(),
|
| 128 |
+
"task_2_medium_syntax": frozenset({"broken_pipeline.py"}),
|
| 129 |
+
"task_3_hard_e2e": frozenset({"format_report.py", "report_data.json"}),
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
TASK_SQL_POLICIES = {
|
| 133 |
+
"task_1_easy_anomaly": SQLPolicy(
|
| 134 |
+
allowed_commands=frozenset({"SELECT", "DELETE"}),
|
| 135 |
+
required_table="transactions",
|
| 136 |
+
),
|
| 137 |
+
"task_3_hard_e2e": SQLPolicy(
|
| 138 |
+
allowed_commands=frozenset({"SELECT", "WITH"}),
|
| 139 |
+
required_table="daily_reports",
|
| 140 |
+
),
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
_REPORT_RECORD_RE = re.compile(
|
| 144 |
+
r"Department:\s*(?P<department>[^\n]+)\n"
|
| 145 |
+
r"\s*Revenue:\s*\$(?P<revenue>-?\d+(?:\.\d+)?)\n"
|
| 146 |
+
r"\s*Expenses:\s*\$(?P<expenses>-?\d+(?:\.\d+)?)\n"
|
| 147 |
+
r"\s*Net:\s*\$(?P<net>-?\d+(?:\.\d+)?)",
|
| 148 |
+
re.MULTILINE,
|
| 149 |
+
)
|
| 150 |
+
_REPORT_TOTAL_RE = re.compile(r"Total Revenue:\s*\$(?P<total>-?\d+(?:\.\d+)?)")
|
| 151 |
+
|
| 152 |
+
_TASK_1_VALID_STATUSES = ("success", "settled", "approved", "completed")
|
| 153 |
+
_TASK_1_CORRUPTED_STATUSES = ("pending", "retrying", "failed", "queued")
|
| 154 |
+
_TASK_2_READY_STATUS = "ready"
|
| 155 |
+
_TASK_2_NON_READY_STATUSES = ("queued", "hold", "failed")
|
| 156 |
+
_TASK_2_REGIONS = ("us-east", "eu-west", "ap-south", "sa-east")
|
| 157 |
+
_TASK_3_RECIPIENTS = (
|
| 158 |
+
"bhavik@example.com",
|
| 159 |
+
"marta@example.com",
|
| 160 |
+
"ops-lead@example.com",
|
| 161 |
+
"finance-review@example.com",
|
| 162 |
+
)
|
| 163 |
+
_TASK_3_DEPARTMENTS = (
|
| 164 |
+
"Engineering",
|
| 165 |
+
"Sales",
|
| 166 |
+
"Marketing",
|
| 167 |
+
"Operations",
|
| 168 |
+
"Support",
|
| 169 |
+
"Finance",
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def task_manifest_entries() -> list[dict[str, Any]]:
|
| 174 |
+
return [
|
| 175 |
+
{
|
| 176 |
+
"id": metadata.task_id,
|
| 177 |
+
"name": metadata.name,
|
| 178 |
+
"difficulty": metadata.difficulty,
|
| 179 |
+
"description": metadata.short_description,
|
| 180 |
+
"benchmark_focus": metadata.benchmark_focus,
|
| 181 |
+
"allowed_actions": list(metadata.allowed_actions),
|
| 182 |
+
}
|
| 183 |
+
for metadata in TASK_METADATA.values()
|
| 184 |
+
]
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def build_task_scenario(task_id: str, seed: int | None = None) -> TaskScenarioBundle:
|
| 188 |
+
resolved_seed = 0 if seed is None else int(seed)
|
| 189 |
+
|
| 190 |
+
if task_id == "task_1_easy_anomaly":
|
| 191 |
+
task = _build_task_1_scenario(resolved_seed)
|
| 192 |
+
return TaskScenarioBundle(
|
| 193 |
+
task_id=task_id,
|
| 194 |
+
seed=resolved_seed,
|
| 195 |
+
description=task.description,
|
| 196 |
+
task_1=task,
|
| 197 |
+
)
|
| 198 |
+
if task_id == "task_2_medium_syntax":
|
| 199 |
+
task = _build_task_2_scenario(resolved_seed)
|
| 200 |
+
return TaskScenarioBundle(
|
| 201 |
+
task_id=task_id,
|
| 202 |
+
seed=resolved_seed,
|
| 203 |
+
description=task.description,
|
| 204 |
+
task_2=task,
|
| 205 |
+
)
|
| 206 |
+
if task_id == "task_3_hard_e2e":
|
| 207 |
+
task = _build_task_3_scenario(resolved_seed)
|
| 208 |
+
return TaskScenarioBundle(
|
| 209 |
+
task_id=task_id,
|
| 210 |
+
seed=resolved_seed,
|
| 211 |
+
description=task.description,
|
| 212 |
+
task_3=task,
|
| 213 |
+
)
|
| 214 |
+
raise KeyError(f"Unknown task_id: {task_id}")
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def normalize_task_3_rows(
|
| 218 |
+
rows: Iterable[dict[str, Any]], *, require_headcount: bool = False
|
| 219 |
+
) -> list[dict[str, Any]]:
|
| 220 |
+
"""Normalise extracted rows for deterministic comparison."""
|
| 221 |
+
normalised: list[dict[str, Any]] = []
|
| 222 |
+
for row in rows:
|
| 223 |
+
try:
|
| 224 |
+
hc_raw = row.get("headcount")
|
| 225 |
+
if hc_raw is None or hc_raw == "":
|
| 226 |
+
if require_headcount:
|
| 227 |
+
return []
|
| 228 |
+
headcount: int | None = None
|
| 229 |
+
else:
|
| 230 |
+
headcount = int(hc_raw)
|
| 231 |
+
normalised.append(
|
| 232 |
+
{
|
| 233 |
+
"department": str(row["department"]),
|
| 234 |
+
"revenue": round(float(row["revenue"]), 2),
|
| 235 |
+
"expenses": round(float(row["expenses"]), 2),
|
| 236 |
+
"headcount": headcount,
|
| 237 |
+
}
|
| 238 |
+
)
|
| 239 |
+
except (KeyError, TypeError, ValueError):
|
| 240 |
+
return []
|
| 241 |
+
return sorted(normalised, key=lambda item: item["department"])
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def normalize_task_2_output_rows(rows: Any) -> list[dict[str, Any]]:
|
| 245 |
+
"""Normalise Task 2 ETL output rows while preserving list order for sort checks."""
|
| 246 |
+
if not isinstance(rows, list):
|
| 247 |
+
return []
|
| 248 |
+
|
| 249 |
+
normalised: list[dict[str, Any]] = []
|
| 250 |
+
for row in rows:
|
| 251 |
+
if not isinstance(row, dict):
|
| 252 |
+
return []
|
| 253 |
+
try:
|
| 254 |
+
order_id = str(row["order_id"])
|
| 255 |
+
region = str(row["region"])
|
| 256 |
+
amount_usd = round(float(row["amount_usd"]), 2)
|
| 257 |
+
priority_band = str(row["priority_band"])
|
| 258 |
+
except (KeyError, TypeError, ValueError):
|
| 259 |
+
return []
|
| 260 |
+
|
| 261 |
+
if priority_band not in {"high", "normal"}:
|
| 262 |
+
return []
|
| 263 |
+
|
| 264 |
+
normalised.append(
|
| 265 |
+
{
|
| 266 |
+
"order_id": order_id,
|
| 267 |
+
"region": region,
|
| 268 |
+
"amount_usd": amount_usd,
|
| 269 |
+
"priority_band": priority_band,
|
| 270 |
+
}
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
return normalised
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def build_task_2_expected(
|
| 277 |
+
batch: Iterable[dict[str, Any]]
|
| 278 |
+
) -> list[dict[str, Any]]:
|
| 279 |
+
processed: list[dict[str, Any]] = []
|
| 280 |
+
|
| 281 |
+
for record in batch:
|
| 282 |
+
try:
|
| 283 |
+
status = str(record["status"])
|
| 284 |
+
amount_cents = int(record["amount_cents"])
|
| 285 |
+
priority = int(record["priority"])
|
| 286 |
+
amount_usd = round(amount_cents / 100.0, 2)
|
| 287 |
+
if status != _TASK_2_READY_STATUS or amount_cents <= 0:
|
| 288 |
+
continue
|
| 289 |
+
processed.append(
|
| 290 |
+
{
|
| 291 |
+
"order_id": str(record["order_id"]),
|
| 292 |
+
"region": str(record["region"]),
|
| 293 |
+
"amount_usd": amount_usd,
|
| 294 |
+
"priority_band": "high"
|
| 295 |
+
if priority >= 8 or amount_usd >= 500.0
|
| 296 |
+
else "normal",
|
| 297 |
+
}
|
| 298 |
+
)
|
| 299 |
+
except (KeyError, TypeError, ValueError):
|
| 300 |
+
return []
|
| 301 |
+
|
| 302 |
+
processed.sort(key=lambda item: (-item["amount_usd"], item["order_id"]))
|
| 303 |
+
return processed
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def task_3_data_matches_expected(
|
| 307 |
+
rows: list[dict[str, Any]],
|
| 308 |
+
expected_rows: Iterable[dict[str, Any]],
|
| 309 |
+
*,
|
| 310 |
+
require_headcount: bool,
|
| 311 |
+
) -> bool:
|
| 312 |
+
expected = normalize_task_3_rows(expected_rows, require_headcount=require_headcount)
|
| 313 |
+
return rows == expected
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def task_3_headcount_fully_matches(
|
| 317 |
+
rows: list[dict[str, Any]], expected_rows: Iterable[dict[str, Any]]
|
| 318 |
+
) -> bool:
|
| 319 |
+
expected = normalize_task_3_rows(expected_rows, require_headcount=True)
|
| 320 |
+
return rows == expected
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def build_task_3_report(rows: Iterable[dict[str, Any]], target_date: str) -> str:
|
| 324 |
+
report_rows = normalize_task_3_rows(rows, require_headcount=True)
|
| 325 |
+
lines = [f"=== Daily Revenue Report ({target_date}) ===", ""]
|
| 326 |
+
total_revenue = 0.0
|
| 327 |
+
|
| 328 |
+
for row in report_rows:
|
| 329 |
+
revenue = float(row["revenue"])
|
| 330 |
+
expenses = float(row["expenses"])
|
| 331 |
+
net = revenue - expenses
|
| 332 |
+
lines.append(f"Department: {row['department']}")
|
| 333 |
+
lines.append(f" Revenue: ${revenue:.2f}")
|
| 334 |
+
lines.append(f" Expenses: ${expenses:.2f}")
|
| 335 |
+
lines.append(f" Net: ${net:.2f}")
|
| 336 |
+
lines.append("")
|
| 337 |
+
total_revenue += revenue
|
| 338 |
+
|
| 339 |
+
lines.append(f"Total Revenue: ${total_revenue:.2f}")
|
| 340 |
+
lines.append("=== End of Report ===")
|
| 341 |
+
return "\n".join(lines)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def extract_task_3_report_block(text: str, target_date: str) -> str | None:
|
| 345 |
+
raw = text.replace("\r\n", "\n")
|
| 346 |
+
start_marker = f"=== Daily Revenue Report ({target_date}) ==="
|
| 347 |
+
start = raw.find(start_marker)
|
| 348 |
+
end_marker = "=== End of Report ==="
|
| 349 |
+
end = raw.find(end_marker)
|
| 350 |
+
if start == -1 or end == -1 or end < start:
|
| 351 |
+
return None
|
| 352 |
+
return raw[start : end + len(end_marker)].strip()
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def parse_task_3_report(text: str, target_date: str) -> dict[str, Any] | None:
|
| 356 |
+
block = extract_task_3_report_block(text, target_date)
|
| 357 |
+
if block is None:
|
| 358 |
+
return None
|
| 359 |
+
|
| 360 |
+
records: list[dict[str, Any]] = []
|
| 361 |
+
for match in _REPORT_RECORD_RE.finditer(block):
|
| 362 |
+
revenue = round(float(match.group("revenue")), 2)
|
| 363 |
+
expenses = round(float(match.group("expenses")), 2)
|
| 364 |
+
net = round(float(match.group("net")), 2)
|
| 365 |
+
records.append(
|
| 366 |
+
{
|
| 367 |
+
"department": match.group("department").strip(),
|
| 368 |
+
"revenue": revenue,
|
| 369 |
+
"expenses": expenses,
|
| 370 |
+
"headcount": None,
|
| 371 |
+
"net": net,
|
| 372 |
+
}
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
total_match = _REPORT_TOTAL_RE.search(block)
|
| 376 |
+
if not total_match:
|
| 377 |
+
return None
|
| 378 |
+
|
| 379 |
+
return {
|
| 380 |
+
"records": sorted(records, key=lambda item: item["department"]),
|
| 381 |
+
"total_revenue": round(float(total_match.group("total")), 2),
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def report_matches_expected(
|
| 386 |
+
text: str, expected_rows: Iterable[dict[str, Any]], target_date: str
|
| 387 |
+
) -> bool:
|
| 388 |
+
parsed = parse_task_3_report(text, target_date)
|
| 389 |
+
if parsed is None:
|
| 390 |
+
return False
|
| 391 |
+
|
| 392 |
+
expected = normalize_task_3_rows(expected_rows, require_headcount=True)
|
| 393 |
+
expected_records = [
|
| 394 |
+
{
|
| 395 |
+
"department": row["department"],
|
| 396 |
+
"revenue": row["revenue"],
|
| 397 |
+
"expenses": row["expenses"],
|
| 398 |
+
"headcount": None,
|
| 399 |
+
"net": round(float(row["revenue"]) - float(row["expenses"]), 2),
|
| 400 |
+
}
|
| 401 |
+
for row in expected
|
| 402 |
+
]
|
| 403 |
+
expected_total = round(sum(float(row["revenue"]) for row in expected), 2)
|
| 404 |
+
return (
|
| 405 |
+
parsed["records"] == expected_records
|
| 406 |
+
and parsed["total_revenue"] == expected_total
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def task_3_semantic_match_fraction_rows(
|
| 411 |
+
rows: list[dict[str, Any]], expected_rows: Iterable[dict[str, Any]]
|
| 412 |
+
) -> float:
|
| 413 |
+
if not rows:
|
| 414 |
+
return 0.0
|
| 415 |
+
expected = normalize_task_3_rows(expected_rows, require_headcount=False)
|
| 416 |
+
exp_by_dept = {row["department"]: row for row in expected}
|
| 417 |
+
matched = 0
|
| 418 |
+
for row in rows:
|
| 419 |
+
department = row.get("department")
|
| 420 |
+
if department not in exp_by_dept:
|
| 421 |
+
continue
|
| 422 |
+
expected_row = exp_by_dept[department]
|
| 423 |
+
if (
|
| 424 |
+
row.get("revenue") == expected_row["revenue"]
|
| 425 |
+
and row.get("expenses") == expected_row["expenses"]
|
| 426 |
+
):
|
| 427 |
+
matched += 1
|
| 428 |
+
return matched / len(expected) if expected else 0.0
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def task_3_semantic_match_fraction_parsed(
|
| 432 |
+
parsed: dict[str, Any] | None, expected_rows: Iterable[dict[str, Any]]
|
| 433 |
+
) -> float:
|
| 434 |
+
if not parsed or not parsed.get("records"):
|
| 435 |
+
return 0.0
|
| 436 |
+
expected = normalize_task_3_rows(expected_rows, require_headcount=False)
|
| 437 |
+
exp_by_dept = {row["department"]: row for row in expected}
|
| 438 |
+
matched = 0
|
| 439 |
+
for record in parsed["records"]:
|
| 440 |
+
department = record.get("department")
|
| 441 |
+
if department not in exp_by_dept:
|
| 442 |
+
continue
|
| 443 |
+
expected_row = exp_by_dept[department]
|
| 444 |
+
if (
|
| 445 |
+
record.get("revenue") == expected_row["revenue"]
|
| 446 |
+
and record.get("expenses") == expected_row["expenses"]
|
| 447 |
+
):
|
| 448 |
+
matched += 1
|
| 449 |
+
return matched / len(expected) if expected else 0.0
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def task_3_semantic_match_fraction_text(
|
| 453 |
+
text: str, expected_rows: Iterable[dict[str, Any]], target_date: str
|
| 454 |
+
) -> float:
|
| 455 |
+
return task_3_semantic_match_fraction_parsed(
|
| 456 |
+
parse_task_3_report(text, target_date), expected_rows
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def _build_task_1_scenario(seed: int) -> Task1Scenario:
|
| 461 |
+
rng = random.Random(f"task-1:{seed}")
|
| 462 |
+
valid_count = 3 + rng.randrange(3)
|
| 463 |
+
corrupted_count = 2 + rng.randrange(2)
|
| 464 |
+
combined_rows: list[dict[str, Any]] = []
|
| 465 |
+
|
| 466 |
+
valid_templates = []
|
| 467 |
+
for index in range(valid_count):
|
| 468 |
+
valid_templates.append(
|
| 469 |
+
{
|
| 470 |
+
"kind": "valid",
|
| 471 |
+
"user_id": 1000 + seed * 10 + index,
|
| 472 |
+
"amount": round(rng.uniform(75.0, 975.0), 2),
|
| 473 |
+
"status": rng.choice(_TASK_1_VALID_STATUSES),
|
| 474 |
+
}
|
| 475 |
+
)
|
| 476 |
+
if valid_templates:
|
| 477 |
+
valid_templates[0]["amount"] = 0.0
|
| 478 |
+
valid_templates[0]["status"] = "settled"
|
| 479 |
+
if len(valid_templates) > 1:
|
| 480 |
+
valid_templates[1]["amount"] = -round(float(valid_templates[1]["amount"]) / 10.0, 2)
|
| 481 |
+
valid_templates[1]["status"] = "approved"
|
| 482 |
+
corrupted_templates = []
|
| 483 |
+
for index in range(corrupted_count):
|
| 484 |
+
corrupted_templates.append(
|
| 485 |
+
{
|
| 486 |
+
"kind": "corrupted",
|
| 487 |
+
"user_id": 2000 + seed * 10 + index,
|
| 488 |
+
"amount": None,
|
| 489 |
+
"status": rng.choice(_TASK_1_CORRUPTED_STATUSES),
|
| 490 |
+
}
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
templates = valid_templates + corrupted_templates
|
| 494 |
+
rng.shuffle(templates)
|
| 495 |
+
|
| 496 |
+
expected_rows: list[dict[str, Any]] = []
|
| 497 |
+
corrupted_row_ids: list[int] = []
|
| 498 |
+
for row_id, template in enumerate(templates, start=1):
|
| 499 |
+
row = {
|
| 500 |
+
"id": row_id,
|
| 501 |
+
"user_id": int(template["user_id"]),
|
| 502 |
+
"amount": template["amount"],
|
| 503 |
+
"status": str(template["status"]),
|
| 504 |
+
}
|
| 505 |
+
combined_rows.append(row)
|
| 506 |
+
if template["kind"] == "valid":
|
| 507 |
+
expected_rows.append(row)
|
| 508 |
+
else:
|
| 509 |
+
corrupted_row_ids.append(row_id)
|
| 510 |
+
|
| 511 |
+
description = (
|
| 512 |
+
"Find and delete all corrupted records (rows with NULL amounts) from the "
|
| 513 |
+
f"'transactions' table. This seeded episode contains {corrupted_count} corrupted "
|
| 514 |
+
f"rows mixed with {valid_count} valid rows. Only NULL amounts are corrupted; "
|
| 515 |
+
"legitimate zero-value reconciliations and negative refund adjustments may also "
|
| 516 |
+
"appear and must be preserved exactly."
|
| 517 |
+
)
|
| 518 |
+
return Task1Scenario(
|
| 519 |
+
description=description,
|
| 520 |
+
all_rows=tuple(combined_rows),
|
| 521 |
+
expected_rows=tuple(expected_rows),
|
| 522 |
+
corrupted_row_ids=tuple(sorted(corrupted_row_ids)),
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
def _build_task_2_scenario(seed: int) -> Task2Scenario:
|
| 527 |
+
rng = random.Random(f"task-2:{seed}")
|
| 528 |
+
visible_batch = _sample_task_2_batch(rng, batch_index=0)
|
| 529 |
+
hidden_cases = tuple(
|
| 530 |
+
_sample_task_2_batch(rng, batch_index=index + 1)
|
| 531 |
+
for index in range(6)
|
| 532 |
+
)
|
| 533 |
+
visible_expected = tuple(build_task_2_expected(visible_batch))
|
| 534 |
+
hidden_expected = tuple(
|
| 535 |
+
tuple(build_task_2_expected(batch)) for batch in hidden_cases
|
| 536 |
+
)
|
| 537 |
+
description = (
|
| 538 |
+
"The script 'broken_pipeline.py' prepares downstream billing candidates from "
|
| 539 |
+
"seeded order records. Repair it so it keeps only ready records with positive "
|
| 540 |
+
"amounts, converts cents to USD, flags high priority when priority >= 8 or "
|
| 541 |
+
"amount_usd >= 500.00, and returns rows sorted by amount_usd descending then "
|
| 542 |
+
"order_id ascending. The grader checks the visible demo batch and additional "
|
| 543 |
+
"unseen seeded batches."
|
| 544 |
+
)
|
| 545 |
+
return Task2Scenario(
|
| 546 |
+
description=description,
|
| 547 |
+
visible_batch=visible_batch,
|
| 548 |
+
visible_expected=visible_expected,
|
| 549 |
+
hidden_cases=hidden_cases,
|
| 550 |
+
hidden_expected=hidden_expected,
|
| 551 |
+
broken_script=_render_broken_pipeline_script(visible_batch),
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def _build_task_3_scenario(seed: int) -> Task3Scenario:
|
| 556 |
+
rng = random.Random(f"task-3:{seed}")
|
| 557 |
+
base_date = date(2025, 3, 25) + timedelta(days=rng.randrange(0, 7))
|
| 558 |
+
target_date = base_date.isoformat()
|
| 559 |
+
recipient = rng.choice(_TASK_3_RECIPIENTS)
|
| 560 |
+
subject = f"Daily Revenue Report - {target_date}"
|
| 561 |
+
report_title = f"Daily Revenue Report ({target_date})"
|
| 562 |
+
selected_departments = sorted(rng.sample(_TASK_3_DEPARTMENTS, k=4))
|
| 563 |
+
|
| 564 |
+
expected_rows: list[dict[str, Any]] = []
|
| 565 |
+
warehouse_rows: list[dict[str, Any]] = []
|
| 566 |
+
row_id = 1
|
| 567 |
+
for offset in (-2, -1, 0, 1):
|
| 568 |
+
report_date = (base_date + timedelta(days=offset)).isoformat()
|
| 569 |
+
for department in selected_departments:
|
| 570 |
+
if offset == 0:
|
| 571 |
+
revenue = round(rng.uniform(12_000.0, 95_000.0), 2)
|
| 572 |
+
expenses = round(rng.uniform(8_000.0, revenue + 18_000.0), 2)
|
| 573 |
+
headcount = rng.randint(8, 48)
|
| 574 |
+
seeded_row = {
|
| 575 |
+
"department": department,
|
| 576 |
+
"revenue": revenue,
|
| 577 |
+
"expenses": expenses,
|
| 578 |
+
"headcount": headcount,
|
| 579 |
+
}
|
| 580 |
+
expected_rows.append(seeded_row)
|
| 581 |
+
else:
|
| 582 |
+
revenue = round(rng.uniform(9_000.0, 90_000.0), 2)
|
| 583 |
+
expenses = round(rng.uniform(7_000.0, revenue + 14_000.0), 2)
|
| 584 |
+
headcount = rng.randint(8, 48)
|
| 585 |
+
warehouse_rows.append(
|
| 586 |
+
{
|
| 587 |
+
"id": row_id,
|
| 588 |
+
"report_date": report_date,
|
| 589 |
+
"department": department,
|
| 590 |
+
"revenue": revenue,
|
| 591 |
+
"expenses": expenses,
|
| 592 |
+
"headcount": headcount,
|
| 593 |
+
}
|
| 594 |
+
)
|
| 595 |
+
row_id += 1
|
| 596 |
+
|
| 597 |
+
description = (
|
| 598 |
+
f"Extract the daily report for date '{target_date}' from the 'daily_reports' table, "
|
| 599 |
+
"repair the broken 'format_report.py' script, save the exact extracted rows to "
|
| 600 |
+
f"'report_data.json', run the script with that file, and send the generated report "
|
| 601 |
+
f"to '{recipient}' with subject '{subject}'. The grader expects the exact seeded slice, "
|
| 602 |
+
"including headcount."
|
| 603 |
+
)
|
| 604 |
+
return Task3Scenario(
|
| 605 |
+
description=description,
|
| 606 |
+
target_date=target_date,
|
| 607 |
+
recipient=recipient,
|
| 608 |
+
subject=subject,
|
| 609 |
+
report_title=report_title,
|
| 610 |
+
all_rows=tuple(warehouse_rows),
|
| 611 |
+
expected_rows=tuple(
|
| 612 |
+
normalize_task_3_rows(expected_rows, require_headcount=True)
|
| 613 |
+
),
|
| 614 |
+
broken_script=_render_broken_format_report_script(target_date),
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
def _sample_task_2_batch(
|
| 619 |
+
rng: random.Random, *, batch_index: int
|
| 620 |
+
) -> tuple[dict[str, Any], ...]:
|
| 621 |
+
def make_record(
|
| 622 |
+
suffix: str,
|
| 623 |
+
*,
|
| 624 |
+
status: str,
|
| 625 |
+
amount_cents: int,
|
| 626 |
+
priority: int,
|
| 627 |
+
) -> dict[str, Any]:
|
| 628 |
+
return {
|
| 629 |
+
"order_id": f"ORD-{batch_index:02d}-{suffix}",
|
| 630 |
+
"status": status,
|
| 631 |
+
"amount_cents": amount_cents,
|
| 632 |
+
"priority": priority,
|
| 633 |
+
"region": rng.choice(_TASK_2_REGIONS),
|
| 634 |
+
}
|
| 635 |
+
|
| 636 |
+
records = [
|
| 637 |
+
make_record(
|
| 638 |
+
"normal",
|
| 639 |
+
status=_TASK_2_READY_STATUS,
|
| 640 |
+
amount_cents=rng.randrange(12_125, 28_975, 25),
|
| 641 |
+
priority=rng.randint(2, 6),
|
| 642 |
+
),
|
| 643 |
+
make_record(
|
| 644 |
+
"priority",
|
| 645 |
+
status=_TASK_2_READY_STATUS,
|
| 646 |
+
amount_cents=rng.randrange(13_175, 32_775, 25),
|
| 647 |
+
priority=rng.randint(8, 10),
|
| 648 |
+
),
|
| 649 |
+
make_record(
|
| 650 |
+
"amount",
|
| 651 |
+
status=_TASK_2_READY_STATUS,
|
| 652 |
+
amount_cents=rng.randrange(50_025, 88_975, 25),
|
| 653 |
+
priority=rng.randint(2, 6),
|
| 654 |
+
),
|
| 655 |
+
make_record(
|
| 656 |
+
"queued",
|
| 657 |
+
status=rng.choice(_TASK_2_NON_READY_STATUSES[:2]),
|
| 658 |
+
amount_cents=rng.randrange(18_125, 42_975, 25),
|
| 659 |
+
priority=rng.randint(4, 9),
|
| 660 |
+
),
|
| 661 |
+
make_record(
|
| 662 |
+
"drop",
|
| 663 |
+
status=_TASK_2_READY_STATUS,
|
| 664 |
+
amount_cents=-rng.randrange(125, 2_975, 25),
|
| 665 |
+
priority=rng.randint(8, 10),
|
| 666 |
+
),
|
| 667 |
+
]
|
| 668 |
+
|
| 669 |
+
if batch_index % 2 == 0:
|
| 670 |
+
records.append(
|
| 671 |
+
make_record(
|
| 672 |
+
"hold",
|
| 673 |
+
status=rng.choice(_TASK_2_NON_READY_STATUSES),
|
| 674 |
+
amount_cents=rng.randrange(24_125, 48_975, 25),
|
| 675 |
+
priority=rng.randint(1, 7),
|
| 676 |
+
)
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
rng.shuffle(records)
|
| 680 |
+
return tuple(records)
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def _render_broken_pipeline_script(
|
| 684 |
+
visible_batch: tuple[dict[str, Any], ...]
|
| 685 |
+
) -> str:
|
| 686 |
+
return textwrap.dedent(
|
| 687 |
+
f'''\
|
| 688 |
+
import json
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
def process_data_stream(payloads):
|
| 692 |
+
"""
|
| 693 |
+
Normalize downstream billing candidates.
|
| 694 |
+
|
| 695 |
+
Keep only records whose status is "ready" and whose amount_cents is positive.
|
| 696 |
+
Convert amount_cents to amount_usd rounded to 2 decimals.
|
| 697 |
+
Mark priority_band as "high" when priority >= 8 or amount_usd >= 500.00.
|
| 698 |
+
Return rows sorted by amount_usd descending, then order_id ascending.
|
| 699 |
+
"""
|
| 700 |
+
processed_records = []
|
| 701 |
+
|
| 702 |
+
for payload in payloads:
|
| 703 |
+
if payload["status"] == "failed" or payload["amount_cents"] <= 0:
|
| 704 |
+
continue
|
| 705 |
+
|
| 706 |
+
amount_usd = round(payload["amount_cents"] // 100, 2)
|
| 707 |
+
priority_band = (
|
| 708 |
+
"high"
|
| 709 |
+
if payload["priority"] >= 8 and amount_usd >= 500.0
|
| 710 |
+
else "normal"
|
| 711 |
+
)
|
| 712 |
+
processed_records.append(
|
| 713 |
+
{{
|
| 714 |
+
"order_id": payload["order_id"],
|
| 715 |
+
"region": payload["region"],
|
| 716 |
+
"amount_usd": amount_usd,
|
| 717 |
+
"priority_band": priority_band,
|
| 718 |
+
}}
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
processed_records.sort(key=lambda item: (item["amount_usd"], item["order_id"]))
|
| 722 |
+
return processed_records
|
| 723 |
+
|
| 724 |
+
if __name__ == "__main__":
|
| 725 |
+
mock_batch = {list(visible_batch)!r}
|
| 726 |
+
print(json.dumps(process_data_stream(mock_batch), indent=2, sort_keys=True))
|
| 727 |
+
'''
|
| 728 |
+
).lstrip()
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
def _render_broken_format_report_script(target_date: str) -> str:
|
| 732 |
+
title = f"=== Daily Revenue Report ({target_date}) ==="
|
| 733 |
+
return textwrap.dedent(
|
| 734 |
+
f'''\
|
| 735 |
+
import json
|
| 736 |
+
import sys
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
def format_report(input_path):
|
| 740 |
+
"""Reads extracted data from JSON and produces a formatted stakeholder report."""
|
| 741 |
+
with open(input_path, encoding="utf-8") as f:
|
| 742 |
+
records = json.load(f)
|
| 743 |
+
|
| 744 |
+
lines = ["{title}", ""]
|
| 745 |
+
total_revenue = 0
|
| 746 |
+
|
| 747 |
+
for rec in records:
|
| 748 |
+
dept = rec["department"]
|
| 749 |
+
rev = int(rec["revenue"]) # BUG 1: int() truncates decimal precision
|
| 750 |
+
exp = rec["expenses"]
|
| 751 |
+
net = rev - exp
|
| 752 |
+
lines.append(f"Department: {{dept}}")
|
| 753 |
+
lines.append(f" Revenue: ${{rev}}")
|
| 754 |
+
lines.append(f" Expenses: ${{exp:.2f}}")
|
| 755 |
+
lines.append(f" Net: ${{net:.2f}}")
|
| 756 |
+
lines.append("")
|
| 757 |
+
total_revenue += rev
|
| 758 |
+
|
| 759 |
+
lines.append(f"Total Revenue: ${{total_revenue}}")
|
| 760 |
+
lines.append("=== End of Report ===")
|
| 761 |
+
|
| 762 |
+
output = "\\n".join(lines)
|
| 763 |
+
print(output)
|
| 764 |
+
return output
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
if __name__ == "__main__":
|
| 768 |
+
if len(sys.argv) < 2:
|
| 769 |
+
print("Usage: python format_report.py <input.json>", file=sys.stderr)
|
| 770 |
+
sys.exit(1)
|
| 771 |
+
format_report(sys.argv[0]) # BUG 2: should be sys.argv[1]
|
| 772 |
+
'''
|
| 773 |
+
).lstrip()
|
tests/test_grading.py
ADDED
|
@@ -0,0 +1,577 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""High-signal regression tests for seeded grading and public API shape."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import shutil
|
| 7 |
+
import sqlite3
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
from fastapi.testclient import TestClient
|
| 12 |
+
|
| 13 |
+
_ROOT = Path(__file__).resolve().parents[1]
|
| 14 |
+
if str(_ROOT) not in sys.path:
|
| 15 |
+
sys.path.insert(0, str(_ROOT))
|
| 16 |
+
|
| 17 |
+
from models import DataOpsAction # noqa: E402
|
| 18 |
+
from server.app import app # noqa: E402
|
| 19 |
+
from server.dataops_env_environment import DataOpsEnvironment # noqa: E402
|
| 20 |
+
from server.grading import evaluate_task # noqa: E402
|
| 21 |
+
from server.task_specs import build_task_3_report # noqa: E402
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _fixed_pipeline_script(visible_batch: list[dict[str, object]]) -> str:
|
| 25 |
+
return f'''\
|
| 26 |
+
import json
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def process_data_stream(payloads):
|
| 30 |
+
processed_records = []
|
| 31 |
+
for payload in payloads:
|
| 32 |
+
if payload["status"] != "ready" or int(payload["amount_cents"]) <= 0:
|
| 33 |
+
continue
|
| 34 |
+
amount_usd = round(int(payload["amount_cents"]) / 100.0, 2)
|
| 35 |
+
priority_band = (
|
| 36 |
+
"high"
|
| 37 |
+
if int(payload["priority"]) >= 8 or amount_usd >= 500.0
|
| 38 |
+
else "normal"
|
| 39 |
+
)
|
| 40 |
+
processed_records.append(
|
| 41 |
+
{{
|
| 42 |
+
"order_id": payload["order_id"],
|
| 43 |
+
"region": payload["region"],
|
| 44 |
+
"amount_usd": amount_usd,
|
| 45 |
+
"priority_band": priority_band,
|
| 46 |
+
}}
|
| 47 |
+
)
|
| 48 |
+
processed_records.sort(key=lambda item: (-item["amount_usd"], item["order_id"]))
|
| 49 |
+
return processed_records
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
if __name__ == "__main__":
|
| 53 |
+
mock_batch = {visible_batch!r}
|
| 54 |
+
print(json.dumps(process_data_stream(mock_batch), indent=2, sort_keys=True))
|
| 55 |
+
'''
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _visible_only_pipeline_stub(
|
| 59 |
+
visible_batch: list[dict[str, object]],
|
| 60 |
+
visible_expected: list[dict[str, object]],
|
| 61 |
+
) -> str:
|
| 62 |
+
return f'''\
|
| 63 |
+
import json
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def process_data_stream(payloads):
|
| 67 |
+
visible = {visible_batch!r}
|
| 68 |
+
if payloads == visible:
|
| 69 |
+
return {visible_expected!r}
|
| 70 |
+
return []
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
if __name__ == "__main__":
|
| 74 |
+
print(json.dumps({visible_expected!r}, indent=2, sort_keys=True))
|
| 75 |
+
'''
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _fixed_format_script(target_date: str) -> str:
|
| 79 |
+
return f'''\
|
| 80 |
+
import json
|
| 81 |
+
import sys
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def format_report(input_path):
|
| 85 |
+
with open(input_path, encoding="utf-8") as f:
|
| 86 |
+
records = json.load(f)
|
| 87 |
+
lines = ["=== Daily Revenue Report ({target_date}) ===", ""]
|
| 88 |
+
total_revenue = 0.0
|
| 89 |
+
for rec in records:
|
| 90 |
+
dept = rec["department"]
|
| 91 |
+
rev = float(rec["revenue"])
|
| 92 |
+
exp = float(rec["expenses"])
|
| 93 |
+
net = rev - exp
|
| 94 |
+
lines.append(f"Department: {{dept}}")
|
| 95 |
+
lines.append(f" Revenue: ${{rev:.2f}}")
|
| 96 |
+
lines.append(f" Expenses: ${{exp:.2f}}")
|
| 97 |
+
lines.append(f" Net: ${{net:.2f}}")
|
| 98 |
+
lines.append("")
|
| 99 |
+
total_revenue += rev
|
| 100 |
+
lines.append(f"Total Revenue: ${{total_revenue:.2f}}")
|
| 101 |
+
lines.append("=== End of Report ===")
|
| 102 |
+
out = "\\n".join(lines)
|
| 103 |
+
print(out)
|
| 104 |
+
return out
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
if __name__ == "__main__":
|
| 108 |
+
if len(sys.argv) < 2:
|
| 109 |
+
print("Usage: python format_report.py <input.json>", file=sys.stderr)
|
| 110 |
+
sys.exit(1)
|
| 111 |
+
format_report(sys.argv[1])
|
| 112 |
+
'''
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def test_seeded_task_3_scenario_is_deterministic() -> None:
|
| 116 |
+
env_a = DataOpsEnvironment()
|
| 117 |
+
env_b = DataOpsEnvironment()
|
| 118 |
+
try:
|
| 119 |
+
env_a.reset(task_id="task_3_hard_e2e", seed=17)
|
| 120 |
+
env_b.reset(task_id="task_3_hard_e2e", seed=17)
|
| 121 |
+
assert env_a.scenario.task_3 == env_b.scenario.task_3
|
| 122 |
+
finally:
|
| 123 |
+
env_a.close()
|
| 124 |
+
env_b.close()
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def test_task_1_perfect_score_seeded() -> None:
|
| 128 |
+
env = DataOpsEnvironment()
|
| 129 |
+
env.reset(task_id="task_1_easy_anomaly", seed=7)
|
| 130 |
+
try:
|
| 131 |
+
obs = env.step(
|
| 132 |
+
DataOpsAction(
|
| 133 |
+
action_type="ExecuteSQL",
|
| 134 |
+
payload={"query": "DELETE FROM transactions WHERE amount IS NULL"},
|
| 135 |
+
)
|
| 136 |
+
)
|
| 137 |
+
assert obs.status == "success"
|
| 138 |
+
out = evaluate_task("task_1_easy_anomaly", env)
|
| 139 |
+
assert out["score"] == 1.0
|
| 140 |
+
finally:
|
| 141 |
+
env.close()
|
| 142 |
+
shutil.rmtree(env.workspace_dir, ignore_errors=True)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def test_task_1_seeded_valid_rows_include_non_null_edge_amounts() -> None:
|
| 146 |
+
env = DataOpsEnvironment()
|
| 147 |
+
env.reset(task_id="task_1_easy_anomaly", seed=7)
|
| 148 |
+
try:
|
| 149 |
+
scenario = env.scenario.task_1
|
| 150 |
+
assert scenario is not None
|
| 151 |
+
amounts = [float(row["amount"]) for row in scenario.expected_rows]
|
| 152 |
+
assert any(amount == 0.0 for amount in amounts)
|
| 153 |
+
assert any(amount < 0.0 for amount in amounts)
|
| 154 |
+
finally:
|
| 155 |
+
env.close()
|
| 156 |
+
shutil.rmtree(env.workspace_dir, ignore_errors=True)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def test_task_1_rewriting_corrupted_rows_scores_zero() -> None:
|
| 160 |
+
env = DataOpsEnvironment()
|
| 161 |
+
env.reset(task_id="task_1_easy_anomaly", seed=7)
|
| 162 |
+
try:
|
| 163 |
+
with sqlite3.connect(env.db_path) as conn:
|
| 164 |
+
conn.execute("UPDATE transactions SET amount = 0 WHERE amount IS NULL")
|
| 165 |
+
conn.commit()
|
| 166 |
+
out = evaluate_task("task_1_easy_anomaly", env)
|
| 167 |
+
assert out["score"] == 0.0
|
| 168 |
+
finally:
|
| 169 |
+
env.close()
|
| 170 |
+
shutil.rmtree(env.workspace_dir, ignore_errors=True)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def test_task_1_deleting_non_null_adjustments_is_penalized() -> None:
|
| 174 |
+
env = DataOpsEnvironment()
|
| 175 |
+
env.reset(task_id="task_1_easy_anomaly", seed=7)
|
| 176 |
+
try:
|
| 177 |
+
obs = env.step(
|
| 178 |
+
DataOpsAction(
|
| 179 |
+
action_type="ExecuteSQL",
|
| 180 |
+
payload={
|
| 181 |
+
"query": "DELETE FROM transactions WHERE amount IS NULL OR amount <= 0"
|
| 182 |
+
},
|
| 183 |
+
)
|
| 184 |
+
)
|
| 185 |
+
assert obs.status == "success"
|
| 186 |
+
assert obs.reward is not None and obs.reward < 0
|
| 187 |
+
out = evaluate_task("task_1_easy_anomaly", env)
|
| 188 |
+
assert out["score"] == 0.0
|
| 189 |
+
finally:
|
| 190 |
+
env.close()
|
| 191 |
+
shutil.rmtree(env.workspace_dir, ignore_errors=True)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def test_reset_only_scores_zero_across_tasks() -> None:
|
| 195 |
+
for task_id in (
|
| 196 |
+
"task_1_easy_anomaly",
|
| 197 |
+
"task_2_medium_syntax",
|
| 198 |
+
"task_3_hard_e2e",
|
| 199 |
+
):
|
| 200 |
+
env = DataOpsEnvironment()
|
| 201 |
+
try:
|
| 202 |
+
env.reset(task_id=task_id, seed=7)
|
| 203 |
+
out = evaluate_task(task_id, env)
|
| 204 |
+
assert out["score"] == 0.0
|
| 205 |
+
finally:
|
| 206 |
+
env.close()
|
| 207 |
+
shutil.rmtree(env.workspace_dir, ignore_errors=True)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def test_task_1_broad_delete_with_where_is_penalized() -> None:
|
| 211 |
+
env = DataOpsEnvironment()
|
| 212 |
+
env.reset(task_id="task_1_easy_anomaly", seed=7)
|
| 213 |
+
try:
|
| 214 |
+
obs = env.step(
|
| 215 |
+
DataOpsAction(
|
| 216 |
+
action_type="ExecuteSQL",
|
| 217 |
+
payload={
|
| 218 |
+
"query": "DELETE FROM transactions WHERE amount IS NULL OR 1 = 1"
|
| 219 |
+
},
|
| 220 |
+
)
|
| 221 |
+
)
|
| 222 |
+
assert obs.status == "success"
|
| 223 |
+
assert obs.reward is not None and obs.reward < 0
|
| 224 |
+
assert env.evidence["task_1"]["destructive_sql_attempted"] is True
|
| 225 |
+
out = evaluate_task("task_1_easy_anomaly", env)
|
| 226 |
+
assert out["score"] == 0.0
|
| 227 |
+
finally:
|
| 228 |
+
env.close()
|
| 229 |
+
shutil.rmtree(env.workspace_dir, ignore_errors=True)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def test_task_2_script_run_does_not_inherit_server_secrets(monkeypatch) -> None:
|
| 233 |
+
monkeypatch.setenv("API_KEY", "super-secret-value")
|
| 234 |
+
env = DataOpsEnvironment()
|
| 235 |
+
env.reset(task_id="task_2_medium_syntax", seed=11)
|
| 236 |
+
script = """\
|
| 237 |
+
import json
|
| 238 |
+
import os
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def process_data_stream(payloads):
|
| 242 |
+
return []
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
if __name__ == "__main__":
|
| 246 |
+
print(json.dumps({"api_key": os.getenv("API_KEY"), "home": os.getenv("HOME")}))
|
| 247 |
+
"""
|
| 248 |
+
try:
|
| 249 |
+
env.step(
|
| 250 |
+
DataOpsAction(
|
| 251 |
+
action_type="WriteFile",
|
| 252 |
+
payload={"filepath": "broken_pipeline.py", "content": script},
|
| 253 |
+
)
|
| 254 |
+
)
|
| 255 |
+
run_obs = env.step(
|
| 256 |
+
DataOpsAction(
|
| 257 |
+
action_type="RunScript",
|
| 258 |
+
payload={"filepath": "broken_pipeline.py", "args": []},
|
| 259 |
+
)
|
| 260 |
+
)
|
| 261 |
+
assert run_obs.status == "success"
|
| 262 |
+
payload = json.loads((run_obs.stdout or "").strip())
|
| 263 |
+
assert payload["api_key"] is None
|
| 264 |
+
assert payload["home"] == env.workspace_dir
|
| 265 |
+
finally:
|
| 266 |
+
env.close()
|
| 267 |
+
shutil.rmtree(env.workspace_dir, ignore_errors=True)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def test_task_2_perfect_score_seeded() -> None:
|
| 271 |
+
env = DataOpsEnvironment()
|
| 272 |
+
env.reset(task_id="task_2_medium_syntax", seed=11)
|
| 273 |
+
scenario = env.scenario.task_2
|
| 274 |
+
assert scenario is not None
|
| 275 |
+
try:
|
| 276 |
+
read_obs = env.step(
|
| 277 |
+
DataOpsAction(
|
| 278 |
+
action_type="ReadFile",
|
| 279 |
+
payload={"filepath": "broken_pipeline.py"},
|
| 280 |
+
)
|
| 281 |
+
)
|
| 282 |
+
assert read_obs.status == "success"
|
| 283 |
+
write_obs = env.step(
|
| 284 |
+
DataOpsAction(
|
| 285 |
+
action_type="WriteFile",
|
| 286 |
+
payload={
|
| 287 |
+
"filepath": "broken_pipeline.py",
|
| 288 |
+
"content": _fixed_pipeline_script(list(scenario.visible_batch)),
|
| 289 |
+
},
|
| 290 |
+
)
|
| 291 |
+
)
|
| 292 |
+
assert write_obs.status == "success"
|
| 293 |
+
pre_run = evaluate_task("task_2_medium_syntax", env)
|
| 294 |
+
assert 0.0 < pre_run["score"] < 1.0
|
| 295 |
+
run_obs = env.step(
|
| 296 |
+
DataOpsAction(
|
| 297 |
+
action_type="RunScript",
|
| 298 |
+
payload={"filepath": "broken_pipeline.py", "args": []},
|
| 299 |
+
)
|
| 300 |
+
)
|
| 301 |
+
assert run_obs.status == "success"
|
| 302 |
+
out = evaluate_task("task_2_medium_syntax", env)
|
| 303 |
+
assert out["score"] == 1.0
|
| 304 |
+
finally:
|
| 305 |
+
env.close()
|
| 306 |
+
shutil.rmtree(env.workspace_dir, ignore_errors=True)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def test_task_2_print_only_stub_does_not_get_full_credit() -> None:
|
| 310 |
+
env = DataOpsEnvironment()
|
| 311 |
+
env.reset(task_id="task_2_medium_syntax", seed=11)
|
| 312 |
+
scenario = env.scenario.task_2
|
| 313 |
+
assert scenario is not None
|
| 314 |
+
stub = _visible_only_pipeline_stub(
|
| 315 |
+
list(scenario.visible_batch),
|
| 316 |
+
list(scenario.visible_expected),
|
| 317 |
+
)
|
| 318 |
+
try:
|
| 319 |
+
env.step(
|
| 320 |
+
DataOpsAction(
|
| 321 |
+
action_type="WriteFile",
|
| 322 |
+
payload={"filepath": "broken_pipeline.py", "content": stub},
|
| 323 |
+
)
|
| 324 |
+
)
|
| 325 |
+
out = evaluate_task("task_2_medium_syntax", env)
|
| 326 |
+
assert out["score"] < 0.5
|
| 327 |
+
finally:
|
| 328 |
+
env.close()
|
| 329 |
+
shutil.rmtree(env.workspace_dir, ignore_errors=True)
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def test_task_3_sql_policy_rejects_literal_table_name_bypass() -> None:
|
| 333 |
+
env = DataOpsEnvironment()
|
| 334 |
+
env.reset(task_id="task_3_hard_e2e", seed=19)
|
| 335 |
+
try:
|
| 336 |
+
obs = env.step(
|
| 337 |
+
DataOpsAction(
|
| 338 |
+
action_type="ExecuteSQL",
|
| 339 |
+
payload={
|
| 340 |
+
"query": (
|
| 341 |
+
"SELECT name FROM sqlite_master "
|
| 342 |
+
"WHERE 'daily_reports' = 'daily_reports'"
|
| 343 |
+
)
|
| 344 |
+
},
|
| 345 |
+
)
|
| 346 |
+
)
|
| 347 |
+
assert obs.status == "error"
|
| 348 |
+
assert "disallowed" in obs.message.lower()
|
| 349 |
+
finally:
|
| 350 |
+
env.close()
|
| 351 |
+
shutil.rmtree(env.workspace_dir, ignore_errors=True)
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def test_task_3_sql_policy_allows_cte_queries_over_daily_reports() -> None:
|
| 355 |
+
env = DataOpsEnvironment()
|
| 356 |
+
env.reset(task_id="task_3_hard_e2e", seed=19)
|
| 357 |
+
scenario = env.scenario.task_3
|
| 358 |
+
assert scenario is not None
|
| 359 |
+
try:
|
| 360 |
+
obs = env.step(
|
| 361 |
+
DataOpsAction(
|
| 362 |
+
action_type="ExecuteSQL",
|
| 363 |
+
payload={
|
| 364 |
+
"query": (
|
| 365 |
+
"WITH scoped AS ("
|
| 366 |
+
"SELECT department, revenue, expenses, headcount "
|
| 367 |
+
"FROM daily_reports "
|
| 368 |
+
f"WHERE report_date = '{scenario.target_date}'"
|
| 369 |
+
") "
|
| 370 |
+
"SELECT department, revenue, expenses, headcount "
|
| 371 |
+
"FROM scoped ORDER BY department"
|
| 372 |
+
)
|
| 373 |
+
},
|
| 374 |
+
)
|
| 375 |
+
)
|
| 376 |
+
assert obs.status == "success"
|
| 377 |
+
assert obs.sql_results
|
| 378 |
+
finally:
|
| 379 |
+
env.close()
|
| 380 |
+
shutil.rmtree(env.workspace_dir, ignore_errors=True)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def test_task_3_perfect_score_requires_proven_workflow() -> None:
|
| 384 |
+
env = DataOpsEnvironment()
|
| 385 |
+
env.reset(task_id="task_3_hard_e2e", seed=19)
|
| 386 |
+
scenario = env.scenario.task_3
|
| 387 |
+
assert scenario is not None
|
| 388 |
+
try:
|
| 389 |
+
query = (
|
| 390 |
+
"SELECT department, revenue, expenses, headcount "
|
| 391 |
+
"FROM daily_reports "
|
| 392 |
+
f"WHERE report_date = '{scenario.target_date}' "
|
| 393 |
+
"ORDER BY department"
|
| 394 |
+
)
|
| 395 |
+
sql_obs = env.step(
|
| 396 |
+
DataOpsAction(
|
| 397 |
+
action_type="ExecuteSQL",
|
| 398 |
+
payload={"query": query},
|
| 399 |
+
)
|
| 400 |
+
)
|
| 401 |
+
assert sql_obs.status == "success"
|
| 402 |
+
rows = sql_obs.sql_results
|
| 403 |
+
assert rows is not None
|
| 404 |
+
|
| 405 |
+
write_json = env.step(
|
| 406 |
+
DataOpsAction(
|
| 407 |
+
action_type="WriteFile",
|
| 408 |
+
payload={"filepath": "report_data.json", "content": json.dumps(rows)},
|
| 409 |
+
)
|
| 410 |
+
)
|
| 411 |
+
assert write_json.status == "success"
|
| 412 |
+
|
| 413 |
+
write_script = env.step(
|
| 414 |
+
DataOpsAction(
|
| 415 |
+
action_type="WriteFile",
|
| 416 |
+
payload={
|
| 417 |
+
"filepath": "format_report.py",
|
| 418 |
+
"content": _fixed_format_script(scenario.target_date),
|
| 419 |
+
},
|
| 420 |
+
)
|
| 421 |
+
)
|
| 422 |
+
assert write_script.status == "success"
|
| 423 |
+
|
| 424 |
+
run_obs = env.step(
|
| 425 |
+
DataOpsAction(
|
| 426 |
+
action_type="RunScript",
|
| 427 |
+
payload={"filepath": "format_report.py", "args": ["report_data.json"]},
|
| 428 |
+
)
|
| 429 |
+
)
|
| 430 |
+
assert run_obs.status == "success"
|
| 431 |
+
body = (run_obs.stdout or "").strip()
|
| 432 |
+
|
| 433 |
+
email_obs = env.step(
|
| 434 |
+
DataOpsAction(
|
| 435 |
+
action_type="SendEmail",
|
| 436 |
+
payload={
|
| 437 |
+
"to_email": scenario.recipient,
|
| 438 |
+
"subject": scenario.subject,
|
| 439 |
+
"body": body,
|
| 440 |
+
},
|
| 441 |
+
)
|
| 442 |
+
)
|
| 443 |
+
assert email_obs.status == "success"
|
| 444 |
+
|
| 445 |
+
out = evaluate_task("task_3_hard_e2e", env)
|
| 446 |
+
assert out["score"] == 1.0
|
| 447 |
+
finally:
|
| 448 |
+
env.close()
|
| 449 |
+
shutil.rmtree(env.workspace_dir, ignore_errors=True)
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def test_task_3_equivalent_relative_input_path_still_scores_perfect() -> None:
|
| 453 |
+
env = DataOpsEnvironment()
|
| 454 |
+
env.reset(task_id="task_3_hard_e2e", seed=29)
|
| 455 |
+
scenario = env.scenario.task_3
|
| 456 |
+
assert scenario is not None
|
| 457 |
+
try:
|
| 458 |
+
query = (
|
| 459 |
+
"SELECT department, revenue, expenses, headcount "
|
| 460 |
+
"FROM daily_reports "
|
| 461 |
+
f"WHERE report_date = '{scenario.target_date}' "
|
| 462 |
+
"ORDER BY department"
|
| 463 |
+
)
|
| 464 |
+
sql_obs = env.step(
|
| 465 |
+
DataOpsAction(
|
| 466 |
+
action_type="ExecuteSQL",
|
| 467 |
+
payload={"query": query},
|
| 468 |
+
)
|
| 469 |
+
)
|
| 470 |
+
assert sql_obs.status == "success"
|
| 471 |
+
rows = sql_obs.sql_results
|
| 472 |
+
assert rows is not None
|
| 473 |
+
|
| 474 |
+
env.step(
|
| 475 |
+
DataOpsAction(
|
| 476 |
+
action_type="WriteFile",
|
| 477 |
+
payload={"filepath": "report_data.json", "content": json.dumps(rows)},
|
| 478 |
+
)
|
| 479 |
+
)
|
| 480 |
+
env.step(
|
| 481 |
+
DataOpsAction(
|
| 482 |
+
action_type="WriteFile",
|
| 483 |
+
payload={
|
| 484 |
+
"filepath": "format_report.py",
|
| 485 |
+
"content": _fixed_format_script(scenario.target_date),
|
| 486 |
+
},
|
| 487 |
+
)
|
| 488 |
+
)
|
| 489 |
+
run_obs = env.step(
|
| 490 |
+
DataOpsAction(
|
| 491 |
+
action_type="RunScript",
|
| 492 |
+
payload={"filepath": "format_report.py", "args": ["./report_data.json"]},
|
| 493 |
+
)
|
| 494 |
+
)
|
| 495 |
+
assert run_obs.status == "success"
|
| 496 |
+
|
| 497 |
+
env.step(
|
| 498 |
+
DataOpsAction(
|
| 499 |
+
action_type="SendEmail",
|
| 500 |
+
payload={
|
| 501 |
+
"to_email": scenario.recipient,
|
| 502 |
+
"subject": scenario.subject,
|
| 503 |
+
"body": (run_obs.stdout or "").strip(),
|
| 504 |
+
},
|
| 505 |
+
)
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
out = evaluate_task("task_3_hard_e2e", env)
|
| 509 |
+
assert out["score"] == 1.0
|
| 510 |
+
finally:
|
| 511 |
+
env.close()
|
| 512 |
+
shutil.rmtree(env.workspace_dir, ignore_errors=True)
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def test_task_3_fabricated_email_only_scores_low() -> None:
|
| 516 |
+
env = DataOpsEnvironment()
|
| 517 |
+
env.reset(task_id="task_3_hard_e2e", seed=23)
|
| 518 |
+
scenario = env.scenario.task_3
|
| 519 |
+
assert scenario is not None
|
| 520 |
+
try:
|
| 521 |
+
fake_body = build_task_3_report(list(scenario.expected_rows), scenario.target_date)
|
| 522 |
+
email_obs = env.step(
|
| 523 |
+
DataOpsAction(
|
| 524 |
+
action_type="SendEmail",
|
| 525 |
+
payload={
|
| 526 |
+
"to_email": scenario.recipient,
|
| 527 |
+
"subject": scenario.subject,
|
| 528 |
+
"body": fake_body,
|
| 529 |
+
},
|
| 530 |
+
)
|
| 531 |
+
)
|
| 532 |
+
assert email_obs.status == "success"
|
| 533 |
+
out = evaluate_task("task_3_hard_e2e", env)
|
| 534 |
+
assert out["score"] <= 0.10
|
| 535 |
+
finally:
|
| 536 |
+
env.close()
|
| 537 |
+
shutil.rmtree(env.workspace_dir, ignore_errors=True)
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def test_task_3_reading_formatter_source_awards_progress_signal() -> None:
|
| 541 |
+
env = DataOpsEnvironment()
|
| 542 |
+
env.reset(task_id="task_3_hard_e2e", seed=31)
|
| 543 |
+
try:
|
| 544 |
+
obs = env.step(
|
| 545 |
+
DataOpsAction(
|
| 546 |
+
action_type="ReadFile",
|
| 547 |
+
payload={"filepath": "format_report.py"},
|
| 548 |
+
)
|
| 549 |
+
)
|
| 550 |
+
assert obs.status == "success"
|
| 551 |
+
assert obs.reward is not None and obs.reward > 0
|
| 552 |
+
finally:
|
| 553 |
+
env.close()
|
| 554 |
+
shutil.rmtree(env.workspace_dir, ignore_errors=True)
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
def test_tasks_endpoint_exposes_manifest_metadata() -> None:
|
| 558 |
+
with TestClient(app) as client:
|
| 559 |
+
response = client.get("/tasks")
|
| 560 |
+
payload = response.json()
|
| 561 |
+
assert response.status_code == 200
|
| 562 |
+
assert len(payload["tasks"]) == 3
|
| 563 |
+
assert payload["tasks"][0]["difficulty"] == "easy"
|
| 564 |
+
assert "action_schema" in payload
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def test_public_grader_hides_details_by_default(monkeypatch) -> None:
|
| 568 |
+
# Do not leak grader details when PUBLIC_GRADER_DETAILS is unset/false (ignore dev .env).
|
| 569 |
+
monkeypatch.setenv("PUBLIC_GRADER_DETAILS", "false")
|
| 570 |
+
with TestClient(app) as client:
|
| 571 |
+
reset = client.post("/reset?task_id=task_1_easy_anomaly", json={"seed": 5})
|
| 572 |
+
assert reset.status_code == 200
|
| 573 |
+
grade = client.get("/grader")
|
| 574 |
+
assert grade.status_code == 200
|
| 575 |
+
payload = grade.json()
|
| 576 |
+
assert "score" in payload
|
| 577 |
+
assert "details" not in payload
|
tests/test_inference_api.py
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Integration-style checks for inference wiring and transport/session flows."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from subprocess import CompletedProcess
|
| 10 |
+
from types import SimpleNamespace
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
from fastapi.testclient import TestClient
|
| 14 |
+
|
| 15 |
+
_ROOT = Path(__file__).resolve().parents[1]
|
| 16 |
+
if str(_ROOT) not in sys.path:
|
| 17 |
+
sys.path.insert(0, str(_ROOT))
|
| 18 |
+
|
| 19 |
+
import inference # noqa: E402
|
| 20 |
+
import env_loader # noqa: E402
|
| 21 |
+
import server.app as app_module # noqa: E402
|
| 22 |
+
from client import DataOpsEnvClient # noqa: E402
|
| 23 |
+
from models import DataOpsAction # noqa: E402
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class _FakeResponse:
|
| 27 |
+
def __init__(self, payload: dict[str, Any]) -> None:
|
| 28 |
+
self._payload = payload
|
| 29 |
+
|
| 30 |
+
def raise_for_status(self) -> None:
|
| 31 |
+
return None
|
| 32 |
+
|
| 33 |
+
def json(self) -> dict[str, Any]:
|
| 34 |
+
return self._payload
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class _FakeHTTPSession:
|
| 38 |
+
def __init__(self) -> None:
|
| 39 |
+
self.urls: list[str] = []
|
| 40 |
+
self._step_count = 0
|
| 41 |
+
|
| 42 |
+
def request(
|
| 43 |
+
self,
|
| 44 |
+
method: str,
|
| 45 |
+
url: str,
|
| 46 |
+
timeout: float | None = None,
|
| 47 |
+
**kwargs: Any,
|
| 48 |
+
) -> _FakeResponse:
|
| 49 |
+
del method, timeout, kwargs
|
| 50 |
+
self.urls.append(url)
|
| 51 |
+
if url.endswith("/reset"):
|
| 52 |
+
self._step_count = 0
|
| 53 |
+
return _FakeResponse(
|
| 54 |
+
{
|
| 55 |
+
"observation": {
|
| 56 |
+
"status": "success",
|
| 57 |
+
"message": "Repair the ETL job.",
|
| 58 |
+
},
|
| 59 |
+
"reward": 0.0,
|
| 60 |
+
"done": False,
|
| 61 |
+
}
|
| 62 |
+
)
|
| 63 |
+
if url.endswith("/step"):
|
| 64 |
+
self._step_count += 1
|
| 65 |
+
return _FakeResponse(
|
| 66 |
+
{
|
| 67 |
+
"observation": {
|
| 68 |
+
"status": "success",
|
| 69 |
+
"message": "Read ok.",
|
| 70 |
+
},
|
| 71 |
+
"reward": 0.0,
|
| 72 |
+
"done": self._step_count >= 2,
|
| 73 |
+
}
|
| 74 |
+
)
|
| 75 |
+
if url.endswith("/grader/task_2_medium_syntax"):
|
| 76 |
+
return _FakeResponse({"score": 0.25})
|
| 77 |
+
raise AssertionError(f"Unexpected URL requested: {url}")
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class _FakeChatCompletions:
|
| 81 |
+
def __init__(self, messages: list[Any]) -> None:
|
| 82 |
+
self._messages = iter(messages)
|
| 83 |
+
|
| 84 |
+
def create(self, **kwargs: Any) -> Any:
|
| 85 |
+
del kwargs
|
| 86 |
+
message = next(self._messages)
|
| 87 |
+
return SimpleNamespace(choices=[SimpleNamespace(message=message)])
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class _FakeClient:
|
| 91 |
+
def __init__(self, messages: list[Any]) -> None:
|
| 92 |
+
self.chat = SimpleNamespace(completions=_FakeChatCompletions(messages))
|
| 93 |
+
self.base_url = "https://model.local/v1"
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _tool_message(name: str, arguments: dict[str, Any]) -> Any:
|
| 97 |
+
return SimpleNamespace(
|
| 98 |
+
tool_calls=[
|
| 99 |
+
SimpleNamespace(
|
| 100 |
+
id="call-1",
|
| 101 |
+
function=SimpleNamespace(
|
| 102 |
+
name=name,
|
| 103 |
+
arguments=json.dumps(arguments),
|
| 104 |
+
),
|
| 105 |
+
)
|
| 106 |
+
]
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def test_inference_run_task_uses_env_base_url(monkeypatch, capsys) -> None:
|
| 111 |
+
fake_http = _FakeHTTPSession()
|
| 112 |
+
fake_client = _FakeClient(
|
| 113 |
+
[
|
| 114 |
+
_tool_message("read_file", {"filepath": "broken_pipeline.py"}),
|
| 115 |
+
SimpleNamespace(tool_calls=[]),
|
| 116 |
+
_tool_message("invoke_python", {"filepath": "broken_pipeline.py", "args": []}),
|
| 117 |
+
]
|
| 118 |
+
)
|
| 119 |
+
monkeypatch.setattr(inference, "ENV_BASE_URL", "http://env.local")
|
| 120 |
+
monkeypatch.setattr(inference, "API_BASE_URL", "https://model.local/v1")
|
| 121 |
+
monkeypatch.setattr(inference, "MODEL_NAME", "mock-model")
|
| 122 |
+
|
| 123 |
+
score = inference.run_task(
|
| 124 |
+
fake_client,
|
| 125 |
+
fake_http,
|
| 126 |
+
"task_2_medium_syntax",
|
| 127 |
+
max_turns=4,
|
| 128 |
+
seed=3,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
assert score == 0.25
|
| 132 |
+
assert fake_http.urls
|
| 133 |
+
assert all(url.startswith("http://env.local") for url in fake_http.urls)
|
| 134 |
+
assert all("model.local" not in url for url in fake_http.urls)
|
| 135 |
+
|
| 136 |
+
stdout = capsys.readouterr().out
|
| 137 |
+
assert "[START]" in stdout
|
| 138 |
+
assert "[STEP]" in stdout
|
| 139 |
+
assert "[END]" in stdout
|
| 140 |
+
assert "success=false" in stdout
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def test_inference_emits_grader_details_to_stderr_when_enabled(monkeypatch, capsys) -> None:
|
| 144 |
+
class _DetailedHTTPSession(_FakeHTTPSession):
|
| 145 |
+
def request(
|
| 146 |
+
self,
|
| 147 |
+
method: str,
|
| 148 |
+
url: str,
|
| 149 |
+
timeout: float | None = None,
|
| 150 |
+
**kwargs: Any,
|
| 151 |
+
) -> _FakeResponse:
|
| 152 |
+
if url.endswith("/grader/task_2_medium_syntax"):
|
| 153 |
+
return _FakeResponse(
|
| 154 |
+
{
|
| 155 |
+
"task_id": "task_2_medium_syntax",
|
| 156 |
+
"score": 0.25,
|
| 157 |
+
"details": {"reason": "Visible repair only"},
|
| 158 |
+
}
|
| 159 |
+
)
|
| 160 |
+
return super().request(method, url, timeout=timeout, **kwargs)
|
| 161 |
+
|
| 162 |
+
fake_http = _DetailedHTTPSession()
|
| 163 |
+
fake_client = _FakeClient(
|
| 164 |
+
[_tool_message("read_file", {"filepath": "broken_pipeline.py"})]
|
| 165 |
+
)
|
| 166 |
+
monkeypatch.setenv("PUBLIC_GRADER_DETAILS", "true")
|
| 167 |
+
monkeypatch.setattr(inference, "ENV_BASE_URL", "http://env.local")
|
| 168 |
+
monkeypatch.setattr(inference, "MODEL_NAME", "mock-model")
|
| 169 |
+
|
| 170 |
+
inference.run_task(
|
| 171 |
+
fake_client,
|
| 172 |
+
fake_http,
|
| 173 |
+
"task_2_medium_syntax",
|
| 174 |
+
max_turns=1,
|
| 175 |
+
seed=3,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
stderr = capsys.readouterr().err.strip()
|
| 179 |
+
assert stderr
|
| 180 |
+
assert json.loads(stderr)["details"]["reason"] == "Visible repair only"
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def test_baseline_endpoint_passes_env_base_url(monkeypatch) -> None:
|
| 184 |
+
captured: dict[str, Any] = {}
|
| 185 |
+
|
| 186 |
+
def fake_run(
|
| 187 |
+
command: list[str],
|
| 188 |
+
*,
|
| 189 |
+
cwd: str,
|
| 190 |
+
capture_output: bool,
|
| 191 |
+
text: bool,
|
| 192 |
+
timeout: float,
|
| 193 |
+
env: dict[str, str],
|
| 194 |
+
) -> CompletedProcess[str]:
|
| 195 |
+
captured["command"] = command
|
| 196 |
+
captured["cwd"] = cwd
|
| 197 |
+
captured["capture_output"] = capture_output
|
| 198 |
+
captured["text"] = text
|
| 199 |
+
captured["timeout"] = timeout
|
| 200 |
+
captured["env"] = env
|
| 201 |
+
stdout = "\n".join(
|
| 202 |
+
[
|
| 203 |
+
"[START] task=task_1_easy_anomaly env=dataops_env model=fake-model",
|
| 204 |
+
"[END] success=true steps=1 score=1.000 rewards=1.00",
|
| 205 |
+
json.dumps(
|
| 206 |
+
{
|
| 207 |
+
"scores": {"task_1_easy_anomaly": 1.0},
|
| 208 |
+
"grades": {
|
| 209 |
+
"task_1_easy_anomaly": {
|
| 210 |
+
"task_id": "task_1_easy_anomaly",
|
| 211 |
+
"score": 1.0,
|
| 212 |
+
"details": {"reason": "Perfect"},
|
| 213 |
+
}
|
| 214 |
+
},
|
| 215 |
+
"average": 1.0,
|
| 216 |
+
"model": "fake-model",
|
| 217 |
+
"metadata": {"env_base_url": "http://127.0.0.1:7860"},
|
| 218 |
+
}
|
| 219 |
+
),
|
| 220 |
+
]
|
| 221 |
+
)
|
| 222 |
+
stderr = json.dumps({"task_id": "task_1_easy_anomaly", "score": 1.0})
|
| 223 |
+
return CompletedProcess(command, 0, stdout=stdout, stderr=stderr)
|
| 224 |
+
|
| 225 |
+
monkeypatch.setenv("API_KEY", "test-key")
|
| 226 |
+
monkeypatch.delenv("ENV_BASE_URL", raising=False)
|
| 227 |
+
monkeypatch.setattr(app_module.subprocess, "run", fake_run)
|
| 228 |
+
|
| 229 |
+
with TestClient(app_module.app) as client:
|
| 230 |
+
response = client.post(
|
| 231 |
+
"/baseline",
|
| 232 |
+
json={
|
| 233 |
+
"task_ids": ["task_1_easy_anomaly"],
|
| 234 |
+
"seed": 7,
|
| 235 |
+
"max_turns": 5,
|
| 236 |
+
},
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
assert response.status_code == 200
|
| 240 |
+
assert "[START] task=task_1_easy_anomaly" in response.json()["stdout"]
|
| 241 |
+
assert response.json()["stderr"] == json.dumps({"task_id": "task_1_easy_anomaly", "score": 1.0})
|
| 242 |
+
assert response.json()["scores"]["task_1_easy_anomaly"] == 1.0
|
| 243 |
+
assert response.json()["grades"]["task_1_easy_anomaly"]["details"]["reason"] == "Perfect"
|
| 244 |
+
assert captured["env"]["ENV_BASE_URL"] == "http://127.0.0.1:7860"
|
| 245 |
+
assert "--seed" in captured["command"]
|
| 246 |
+
assert "--max-turns" in captured["command"]
|
| 247 |
+
assert "--task" in captured["command"]
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def test_inference_default_api_base_url_uses_google_for_api_key(
|
| 251 |
+
monkeypatch,
|
| 252 |
+
) -> None:
|
| 253 |
+
monkeypatch.setenv("API_KEY", "test-key")
|
| 254 |
+
monkeypatch.delenv("HF_TOKEN", raising=False)
|
| 255 |
+
monkeypatch.delenv("API_BASE_URL", raising=False)
|
| 256 |
+
|
| 257 |
+
assert (
|
| 258 |
+
inference._resolve_api_base_url()
|
| 259 |
+
== inference.DEFAULT_GOOGLE_OPENAI_BASE_URL
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def test_inference_default_api_base_url_uses_hf_router_for_hf_token(
|
| 264 |
+
monkeypatch,
|
| 265 |
+
) -> None:
|
| 266 |
+
monkeypatch.setenv("HF_TOKEN", "test-token")
|
| 267 |
+
monkeypatch.delenv("API_KEY", raising=False)
|
| 268 |
+
monkeypatch.delenv("API_BASE_URL", raising=False)
|
| 269 |
+
|
| 270 |
+
assert inference._resolve_api_base_url() == inference.DEFAULT_HF_OPENAI_BASE_URL
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def test_session_id_header_can_resume_http_episode() -> None:
|
| 274 |
+
with TestClient(app_module.app) as client:
|
| 275 |
+
reset = client.post("/reset?task_id=task_1_easy_anomaly", json={"seed": 5})
|
| 276 |
+
assert reset.status_code == 200
|
| 277 |
+
session_id = reset.headers["X-Session-ID"]
|
| 278 |
+
client.cookies.clear()
|
| 279 |
+
|
| 280 |
+
state = client.get("/state", headers={"X-Session-ID": session_id})
|
| 281 |
+
assert state.status_code == 200
|
| 282 |
+
payload = state.json()
|
| 283 |
+
assert payload["task_id"] == "task_1_easy_anomaly"
|
| 284 |
+
assert payload["seed"] == 5
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def test_reset_replaces_unknown_client_supplied_session_id() -> None:
|
| 288 |
+
with TestClient(app_module.app) as client:
|
| 289 |
+
reset = client.post(
|
| 290 |
+
"/reset?task_id=task_1_easy_anomaly",
|
| 291 |
+
headers={"X-Session-ID": "attacker-chosen-session"},
|
| 292 |
+
json={"seed": 4},
|
| 293 |
+
)
|
| 294 |
+
assert reset.status_code == 200
|
| 295 |
+
issued_session_id = reset.headers["X-Session-ID"]
|
| 296 |
+
assert issued_session_id != "attacker-chosen-session"
|
| 297 |
+
|
| 298 |
+
client.cookies.clear()
|
| 299 |
+
forged_state = client.get("/state", headers={"X-Session-ID": "attacker-chosen-session"})
|
| 300 |
+
assert forged_state.status_code == 400
|
| 301 |
+
|
| 302 |
+
restored_state = client.get("/state", headers={"X-Session-ID": issued_session_id})
|
| 303 |
+
assert restored_state.status_code == 200
|
| 304 |
+
assert restored_state.json()["seed"] == 4
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def test_websocket_reset_state_and_step_flow() -> None:
|
| 308 |
+
with TestClient(app_module.app) as client:
|
| 309 |
+
with client.websocket_connect("/ws") as websocket:
|
| 310 |
+
websocket.send_json(
|
| 311 |
+
{
|
| 312 |
+
"type": "reset",
|
| 313 |
+
"data": {"task_id": "task_1_easy_anomaly", "seed": 3},
|
| 314 |
+
}
|
| 315 |
+
)
|
| 316 |
+
reset_payload = websocket.receive_json()
|
| 317 |
+
assert reset_payload["data"]["observation"]["status"] == "success"
|
| 318 |
+
|
| 319 |
+
websocket.send_json({"type": "state"})
|
| 320 |
+
state_payload = websocket.receive_json()
|
| 321 |
+
assert state_payload["data"]["task_id"] == "task_1_easy_anomaly"
|
| 322 |
+
|
| 323 |
+
websocket.send_json(
|
| 324 |
+
{
|
| 325 |
+
"type": "step",
|
| 326 |
+
"data": {
|
| 327 |
+
"action_type": "ExecuteSQL",
|
| 328 |
+
"payload": {
|
| 329 |
+
"query": (
|
| 330 |
+
"SELECT id, amount FROM transactions "
|
| 331 |
+
"WHERE amount IS NULL ORDER BY id"
|
| 332 |
+
)
|
| 333 |
+
},
|
| 334 |
+
},
|
| 335 |
+
}
|
| 336 |
+
)
|
| 337 |
+
step_payload = websocket.receive_json()
|
| 338 |
+
assert step_payload["data"]["observation"]["status"] == "success"
|
| 339 |
+
assert step_payload["data"]["observation"]["sql_results"]
|
| 340 |
+
|
| 341 |
+
websocket.send_json({"type": "close", "data": {}})
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def test_http_client_overlays_top_level_reward_and_done() -> None:
|
| 345 |
+
class _FakeSession:
|
| 346 |
+
def post(self, url: str, **kwargs: Any) -> _FakeResponse:
|
| 347 |
+
del kwargs
|
| 348 |
+
if url.endswith("/reset"):
|
| 349 |
+
return _FakeResponse(
|
| 350 |
+
{
|
| 351 |
+
"observation": {"status": "success", "message": "ready"},
|
| 352 |
+
"reward": 0.0,
|
| 353 |
+
"done": False,
|
| 354 |
+
}
|
| 355 |
+
)
|
| 356 |
+
if url.endswith("/step"):
|
| 357 |
+
return _FakeResponse(
|
| 358 |
+
{
|
| 359 |
+
"observation": {"status": "success", "message": "ok"},
|
| 360 |
+
"reward": 0.25,
|
| 361 |
+
"done": True,
|
| 362 |
+
}
|
| 363 |
+
)
|
| 364 |
+
raise AssertionError(f"Unexpected URL requested: {url}")
|
| 365 |
+
|
| 366 |
+
def get(self, url: str, **kwargs: Any) -> _FakeResponse:
|
| 367 |
+
del kwargs
|
| 368 |
+
raise AssertionError(f"Unexpected URL requested: {url}")
|
| 369 |
+
|
| 370 |
+
def close(self) -> None:
|
| 371 |
+
return None
|
| 372 |
+
|
| 373 |
+
client = DataOpsEnvClient(base_url="http://env.local")
|
| 374 |
+
client._session = _FakeSession()
|
| 375 |
+
try:
|
| 376 |
+
reset_obs = client.reset(task_id="task_1_easy_anomaly", seed=5)
|
| 377 |
+
assert reset_obs.reward == 0.0
|
| 378 |
+
assert reset_obs.done is False
|
| 379 |
+
|
| 380 |
+
step_obs = client.step(
|
| 381 |
+
DataOpsAction(
|
| 382 |
+
action_type="ExecuteSQL",
|
| 383 |
+
payload={"query": "SELECT 1"},
|
| 384 |
+
)
|
| 385 |
+
)
|
| 386 |
+
assert step_obs.reward == 0.25
|
| 387 |
+
assert step_obs.done is True
|
| 388 |
+
finally:
|
| 389 |
+
client.close()
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def test_env_loader_uses_root_env_to_find_secondary_env_file(
|
| 393 |
+
tmp_path: Path, monkeypatch
|
| 394 |
+
) -> None:
|
| 395 |
+
monkeypatch.setattr(env_loader, "_PROJECT_ROOT", tmp_path)
|
| 396 |
+
monkeypatch.delenv("PUBLIC_GRADER_DETAILS", raising=False)
|
| 397 |
+
monkeypatch.delenv("MODEL_NAME", raising=False)
|
| 398 |
+
|
| 399 |
+
(tmp_path / ".env").write_text("ENV_FILE=.env.dev\n", encoding="utf-8")
|
| 400 |
+
(tmp_path / ".env.dev").write_text(
|
| 401 |
+
"PUBLIC_GRADER_DETAILS=true\nMODEL_NAME=debug-model\n",
|
| 402 |
+
encoding="utf-8",
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
env_loader.load_env()
|
| 406 |
+
|
| 407 |
+
assert os.getenv("PUBLIC_GRADER_DETAILS") == "true"
|
| 408 |
+
assert os.getenv("MODEL_NAME") == "debug-model"
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|