Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +2 -0
- Dockerfile +95 -0
- README.md +151 -6
- __init__.py +29 -0
- _cli.py +47 -0
- actors/__init__.py +1 -0
- actors/dba_oracle.py +183 -0
- actors/engineering_manager.py +81 -0
- client.py +121 -0
- engine/__init__.py +1 -0
- engine/drift.py +242 -0
- engine/profiler.py +298 -0
- engine/reward.py +437 -0
- engine/runtime.py +101 -0
- engine/verifier.py +101 -0
- models.py +364 -0
- openenv.yaml +7 -0
- pyproject.toml +163 -0
- scenarios/__init__.py +72 -0
- scenarios/_fixtures.py +119 -0
- scenarios/base.py +230 -0
- scenarios/scenario_01_correlated_subquery.py +93 -0
- scenarios/scenario_02_select_star_join.py +119 -0
- scenarios/scenario_03_cartesian_join.py +93 -0
- scenarios/scenario_04_distinct_groupby.py +73 -0
- scenarios/scenario_05_nested_subquery.py +100 -0
- scenarios/scenario_06_having_as_where.py +92 -0
- scenarios/scenario_07_drift_column_rename.py +73 -0
- scenarios/scenario_08_drift_date_format.py +93 -0
- scenarios/scenario_09_drift_enum_rule.py +85 -0
- scenarios/scenario_10_drift_field_deprecation.py +74 -0
- server/__init__.py +5 -0
- server/app.py +89 -0
- server/requirements.txt +116 -0
- server/settings.py +31 -0
- server/sql_drift_env_environment.py +899 -0
- skill_library/__init__.py +54 -0
- skill_library/entries.py +93 -0
- skill_library/retrieval.py +85 -0
- skill_library/seeds.py +126 -0
- skill_library/store.py +257 -0
- sql_drift_grpo_training.ipynb +373 -0
- training/__init__.py +27 -0
- training/config.py +127 -0
- training/eval.py +519 -0
- training/evidence/grpo_components_combined.png +3 -0
- training/evidence/grpo_loss_curve.png +0 -0
- training/evidence/grpo_metrics.csv +21 -0
- training/evidence/grpo_r_correct_curve.png +0 -0
- training/evidence/grpo_r_drift_curve.png +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
training/evidence/grpo_components_combined.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
training/evidence/grpo_reward_curve.png filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Multi-stage build using openenv-base
|
| 8 |
+
# This Dockerfile is flexible and works for both:
|
| 9 |
+
# - In-repo environments (with local OpenEnv sources)
|
| 10 |
+
# - Standalone environments (with openenv from PyPI/Git)
|
| 11 |
+
# The build script (openenv build) handles context detection and sets appropriate build args.
|
| 12 |
+
|
| 13 |
+
# ─── IMAGE PINNING ────────────────────────────────────────────────────────────
|
| 14 |
+
# Using :latest is a moving target: a base-image update between local build
|
| 15 |
+
# and judging rebuild can pull a different Python / CUDA / OS version.
|
| 16 |
+
#
|
| 17 |
+
# TO PIN (run once, commit the digest):
|
| 18 |
+
# docker pull ghcr.io/meta-pytorch/openenv-base:latest
|
| 19 |
+
# docker inspect --format='{{index .RepoDigests 0}}' ghcr.io/meta-pytorch/openenv-base:latest
|
| 20 |
+
# # → ghcr.io/meta-pytorch/openenv-base@sha256:<digest>
|
| 21 |
+
#
|
| 22 |
+
# Then replace the ARG default below with the digest, e.g.:
|
| 23 |
+
# ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base@sha256:<digest>
|
| 24 |
+
#
|
| 25 |
+
# The sha256 digest is immutable; the image will not change between builds.
|
| 26 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 27 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 28 |
+
FROM ${BASE_IMAGE} AS builder
|
| 29 |
+
|
| 30 |
+
WORKDIR /app
|
| 31 |
+
|
| 32 |
+
# Ensure git is available (required for installing dependencies from VCS)
|
| 33 |
+
RUN apt-get update && \
|
| 34 |
+
apt-get install -y --no-install-recommends git && \
|
| 35 |
+
rm -rf /var/lib/apt/lists/*
|
| 36 |
+
|
| 37 |
+
# Build argument to control whether we're building standalone or in-repo
|
| 38 |
+
ARG BUILD_MODE=in-repo
|
| 39 |
+
ARG ENV_NAME=sql_drift_env
|
| 40 |
+
|
| 41 |
+
# Copy environment code (always at root of build context)
|
| 42 |
+
COPY . /app/env
|
| 43 |
+
|
| 44 |
+
# For in-repo builds, openenv is already vendored in the build context
|
| 45 |
+
# For standalone builds, openenv will be installed via pyproject.toml
|
| 46 |
+
WORKDIR /app/env
|
| 47 |
+
|
| 48 |
+
# Ensure uv is available (for local builds where base image lacks it)
|
| 49 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 50 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 51 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 52 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 53 |
+
fi
|
| 54 |
+
|
| 55 |
+
# Install dependencies from the frozen lock file. Failing fast when uv.lock
|
| 56 |
+
# is absent prevents a silent re-resolution that could pull different versions
|
| 57 |
+
# than the ones tested (see design/codereview.md, dependency pins).
|
| 58 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 59 |
+
uv sync --frozen --no-install-project --no-editable
|
| 60 |
+
|
| 61 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 62 |
+
uv sync --frozen --no-editable
|
| 63 |
+
|
| 64 |
+
# Final runtime stage — inherits the same pinned/latest image chosen above.
|
| 65 |
+
FROM ${BASE_IMAGE}
|
| 66 |
+
|
| 67 |
+
WORKDIR /app
|
| 68 |
+
|
| 69 |
+
# Copy the uv-managed Python interpreter from the builder so that the
|
| 70 |
+
# shebangs baked into ``/app/env/.venv/bin/*`` (e.g. uvicorn) resolve.
|
| 71 |
+
# Without this, execve() on those scripts silently falls back to the
|
| 72 |
+
# system ``/usr/local/bin/uvicorn`` (Python 3.11) and loads the wrong
|
| 73 |
+
# site-packages, yielding ``ModuleNotFoundError: duckdb`` at startup.
|
| 74 |
+
COPY --from=builder /root/.local/share/uv /root/.local/share/uv
|
| 75 |
+
|
| 76 |
+
# Copy the environment code (includes the builder's .venv at /app/env/.venv)
|
| 77 |
+
COPY --from=builder /app/env /app/env
|
| 78 |
+
|
| 79 |
+
# Point PATH + PYTHONPATH at the venv that shipped with the env so
|
| 80 |
+
# both ``uvicorn ...`` and ``python -m uvicorn ...`` resolve against
|
| 81 |
+
# the right interpreter.
|
| 82 |
+
ENV PATH="/app/env/.venv/bin:$PATH"
|
| 83 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 84 |
+
|
| 85 |
+
# Health check (hits the stateless /health endpoint exposed by the
|
| 86 |
+
# OpenEnv FastAPI factory).
|
| 87 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 88 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 89 |
+
|
| 90 |
+
# Run the FastAPI server via the venv's Python interpreter directly
|
| 91 |
+
# (bypassing any PATH-lookup ambiguity with the base image's system
|
| 92 |
+
# uvicorn).
|
| 93 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 94 |
+
CMD ["/app/env/.venv/bin/python", "-m", "uvicorn", "server.app:app", \
|
| 95 |
+
"--host", "0.0.0.0", "--port", "8000", "--app-dir", "/app/env"]
|
README.md
CHANGED
|
@@ -1,10 +1,155 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: SQLDrift
|
| 3 |
+
emoji: 🐘
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 8000
|
| 8 |
+
short_description: SQL repair and optimization under live schema drift
|
| 9 |
+
base_path: /web
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# SQLDrift
|
| 13 |
+
|
| 14 |
+
> An OpenEnv gym that teaches an LLM agent to **repair and optimize SQL** when
|
| 15 |
+
> the database schema and business rules shift out from under it.
|
| 16 |
+
|
| 17 |
+
SQLDrift is a production-grade [OpenEnv](https://github.com/meta-pytorch/OpenEnv)
|
| 18 |
+
environment designed for RL fine-tuning with TRL's
|
| 19 |
+
[`GRPOTrainer`](https://huggingface.co/docs/trl/openenv) (plus
|
| 20 |
+
[Unsloth](https://github.com/unslothai/unsloth) LoRA on consumer GPUs).
|
| 21 |
+
Each episode hands the agent a slow baseline SQL query, a lightly-populated
|
| 22 |
+
DuckDB database, a toolbox of read-only introspection / execution tools, and a
|
| 23 |
+
**25-step budget**. Midway through the episode a schema or business-rule
|
| 24 |
+
**drift event** may fire; the agent must read the deploy changelog, re-diagnose,
|
| 25 |
+
and submit a rewrite that is both (a) semantically correct and (b) ≥ 1.2× faster
|
| 26 |
+
than the baseline. A hierarchical rubric turns that into six reward signals
|
| 27 |
+
(correctness, drift-adapt, speedup, step-tax, gatekeepers, consult-dba).
|
| 28 |
+
|
| 29 |
+
Highlights:
|
| 30 |
+
|
| 31 |
+
- **10 hand-crafted scenarios** — 6 classic anti-patterns (correlated
|
| 32 |
+
subqueries, `SELECT *` joins, Cartesian joins, `DISTINCT` groupbys, nested
|
| 33 |
+
subqueries, `HAVING`-as-`WHERE`) and 4 drift events (column rename, date
|
| 34 |
+
format flip, enum rule split, field deprecation).
|
| 35 |
+
- **Deterministic fixtures** — 15–30-column schemas, 2–4 tables per scenario,
|
| 36 |
+
regenerated in-process at `reset()` from a seed. No Parquet files, no
|
| 37 |
+
pre-baked sqlite dumps.
|
| 38 |
+
- **Sqlglot-canonicalised baseline-verbatim gate** — agents that rename
|
| 39 |
+
whitespace and resubmit the baseline cannot farm the +1.0 correct bonus.
|
| 40 |
+
- **Skill library** — 12 pre-seeded playbook/drift-card entries plus an
|
| 41 |
+
on-disk, `fcntl.flock`-guarded JSON store that grows as the agent solves
|
| 42 |
+
episodes, surfaced next `reset()` via Jaccard retrieval.
|
| 43 |
+
- **Feature-flagged DBA Oracle** — 3-tier escalating hints per scenario,
|
| 44 |
+
penalised by the `ConsultDBA` rubric. Off by default.
|
| 45 |
+
- **Dockerised** — `server/Dockerfile` layers the env on
|
| 46 |
+
`ghcr.io/meta-pytorch/openenv-base` and exposes `/health`, `/reset`,
|
| 47 |
+
`/step`, `/ws` out-of-the-box.
|
| 48 |
+
|
| 49 |
+
## Quick start
|
| 50 |
+
|
| 51 |
+
```bash
|
| 52 |
+
git clone <your-fork>/sql_drift_env.git
|
| 53 |
+
cd sql_drift_env
|
| 54 |
+
uv sync # install deps
|
| 55 |
+
uv run pytest -q # 300+ tests, ~90s on CPU
|
| 56 |
+
uv run uvicorn server.app:app --reload # serve the env on :8000
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
Or via Docker:
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
docker build -f server/Dockerfile -t sql-drift-env:latest .
|
| 63 |
+
docker run -p 8000:8000 sql-drift-env:latest
|
| 64 |
+
curl -s http://localhost:8000/health
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
## Programmatic rollout
|
| 68 |
+
|
| 69 |
+
A five-step in-process episode against scenario `01_correlated_subquery`:
|
| 70 |
+
|
| 71 |
+
```python
|
| 72 |
+
from client import SqlDriftEnv
|
| 73 |
+
from server import SqlDriftEnvironment
|
| 74 |
+
|
| 75 |
+
env = SqlDriftEnvironment()
|
| 76 |
+
obs = env.reset(seed=42, scenario_id="01_correlated_subquery")
|
| 77 |
+
print(obs.learned_hints) # surfaced from skill library
|
| 78 |
+
|
| 79 |
+
obs = env.step(SqlDriftEnv.action_list_tables())
|
| 80 |
+
obs = env.step(SqlDriftEnv.action_describe_table("users"))
|
| 81 |
+
obs = env.step(SqlDriftEnv.action_run_query("SELECT COUNT(*) FROM users"))
|
| 82 |
+
obs = env.step(SqlDriftEnv.action_submit_rewrite(
|
| 83 |
+
"SELECT u.*, COALESCE(c.n, 0) FROM users u "
|
| 84 |
+
"LEFT JOIN (SELECT user_id, COUNT(*) AS n FROM orders GROUP BY 1) c "
|
| 85 |
+
"ON c.user_id = u.user_id"
|
| 86 |
+
))
|
| 87 |
+
print(obs.reward, obs.reward_components)
|
| 88 |
+
env.close()
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
End-to-end over an HTTP+WS OpenEnv server, see `SqlDriftEnv` in
|
| 92 |
+
`client.py` and the integration test suite in
|
| 93 |
+
`tests/integration/test_client_server.py` /
|
| 94 |
+
`tests/integration/test_state_no_leak.py`.
|
| 95 |
+
|
| 96 |
+
## Evaluation
|
| 97 |
+
|
| 98 |
+
```bash
|
| 99 |
+
uv run python -m training.eval \
|
| 100 |
+
--checkpoint base \
|
| 101 |
+
--scenarios 1-10 \
|
| 102 |
+
--seeds-per-scenario 5 \
|
| 103 |
+
--out outputs/evals/my_run
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
Emits `report.md`, `per_episode.csv`, and `summary.json`. A baseline random
|
| 107 |
+
agent sweep is checked in at `outputs/evals/baseline_random_v1/report.md` as a
|
| 108 |
+
floor: mean reward ≈ **-2.05**, pass rate **0 %**.
|
| 109 |
+
|
| 110 |
+
## Training (GPU)
|
| 111 |
+
|
| 112 |
+
`training/grpo_train.py` contains the GRPO entrypoint used by the hackathon
|
| 113 |
+
training notebook: it builds the curriculum dataset, loads
|
| 114 |
+
`unsloth/Qwen3-1.7B-unsloth-bnb-4bit`, attaches PEFT LoRA adapters, and lets TRL's
|
| 115 |
+
`GRPOTrainer` drive multi-turn OpenEnv rollouts through `SqlDriftToolEnv`.
|
| 116 |
+
Run `training/sql_drift_grpo_training.ipynb` on a GPU runtime with
|
| 117 |
+
`SQL_DRIFT_ENV_URL` pointed at the deployed SQLDrift Space; the notebook writes
|
| 118 |
+
durable evidence plots to `training/evidence/`.
|
| 119 |
+
|
| 120 |
+
```bash
|
| 121 |
+
uv sync --extra train # installs trl, transformers, accelerate, unsloth
|
| 122 |
+
uv sync --extra evidence # matplotlib + pandas for utilities/plot_curves.py
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
## Repository layout
|
| 126 |
+
|
| 127 |
+
```
|
| 128 |
+
sql_drift_env/
|
| 129 |
+
├── models.py # Pydantic v2 action/observation/state
|
| 130 |
+
├── client.py # SqlDriftEnv EnvClient (/ws)
|
| 131 |
+
├── engine/
|
| 132 |
+
│ ├── runtime.py # private RuntimeEpisodeState
|
| 133 |
+
│ ├── drift.py # 4 DDL drift operators
|
| 134 |
+
│ ├── reward.py # SqlDriftRubric (6 child rubrics)
|
| 135 |
+
│ ├── verifier.py / profiler.py
|
| 136 |
+
├── scenarios/ # 10 hand-crafted scenario modules + registry
|
| 137 |
+
├── skill_library/ # pre-seeds + JSON store + Jaccard retrieval
|
| 138 |
+
├── actors/ # engineering_manager (changelog), dba_oracle
|
| 139 |
+
├── server/ # FastAPI app, Dockerfile, env wrapper class
|
| 140 |
+
├── training/ # config, prompt, random_agent, grpo_train, eval
|
| 141 |
+
├── utilities/ # env_loader, logger, plot_curves, run_training_job (HF Jobs), …
|
| 142 |
+
├── tests/ # 300+ unit + integration tests
|
| 143 |
+
├── outputs/evals/ # baseline eval artifacts
|
| 144 |
+
└── design/ # design docs (ignored by docker)
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
## Design references
|
| 148 |
+
|
| 149 |
+
- `.cursor/plans/sqldrift_technical_spec_feb249d1.plan.md` — full 16-phase
|
| 150 |
+
technical spec (Rev-3).
|
| 151 |
+
- `design/CLAUDE.md` — agent coding guidelines.
|
| 152 |
+
|
| 153 |
+
## License
|
| 154 |
+
|
| 155 |
+
BSD-3-Clause (matches the OpenEnv project template).
|
__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SQLDrift — OpenEnv gym for SQL repair & optimization under drift.
|
| 2 |
+
|
| 3 |
+
The repo uses a flat top-level module layout (``models.py``,
|
| 4 |
+
``client.py``, ``engine/``, ``scenarios/``, ``skill_library/``,
|
| 5 |
+
``actors/``, ``server/``, ``training/``) because it is also run as a
|
| 6 |
+
FastAPI server that imports siblings absolutely (``from models import
|
| 7 |
+
…``). setuptools republishes ``.`` as the ``sql_drift_env`` package so
|
| 8 |
+
both import styles work at runtime, but eagerly re-exporting the flat
|
| 9 |
+
submodules from here would shadow the top-level ``import models`` /
|
| 10 |
+
``import client`` paths that every flat module relies on, and would
|
| 11 |
+
make the import order pytest-collection-sensitive.
|
| 12 |
+
|
| 13 |
+
The public API for agent code is therefore the flat modules themselves,
|
| 14 |
+
imported directly:
|
| 15 |
+
|
| 16 |
+
from client import SqlDriftEnv
|
| 17 |
+
from models import SqlDriftAction, SqlDriftObservation
|
| 18 |
+
from server import SqlDriftEnvironment
|
| 19 |
+
|
| 20 |
+
This mirrors the flat layout both on disk and at import time; the
|
| 21 |
+
``sql_drift_env`` namespace exists only so the wheel has a canonical
|
| 22 |
+
name and so third parties can depend on a stable version string.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
__version__ = "0.1.0"
|
| 28 |
+
|
| 29 |
+
__all__ = ["__version__"]
|
_cli.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Console-script entrypoint that patches sys.path for the flat-import layout.
|
| 2 |
+
|
| 3 |
+
Problem
|
| 4 |
+
-------
|
| 5 |
+
The project uses a *flat layout* (``pyproject.toml`` maps ``.`` → the
|
| 6 |
+
``sql_drift_env`` package directory). That means every sibling module
|
| 7 |
+
(``models``, ``actors``, ``engine`` …) is imported as a plain top-level
|
| 8 |
+
name rather than via the ``sql_drift_env.`` prefix.
|
| 9 |
+
|
| 10 |
+
When the wheel is installed, those siblings land at
|
| 11 |
+
``site-packages/sql_drift_env/models.py`` etc., *not* at the top-level
|
| 12 |
+
``site-packages/`` directory. A naïve console-script that calls
|
| 13 |
+
``sql_drift_env.server.app:main`` would fail at ``from models import …``
|
| 14 |
+
before reaching any application logic.
|
| 15 |
+
|
| 16 |
+
Fix
|
| 17 |
+
---
|
| 18 |
+
Insert the installed package directory (``site-packages/sql_drift_env/``)
|
| 19 |
+
onto ``sys.path`` *before* importing anything from the server package.
|
| 20 |
+
This mirrors what Docker achieves via ``--app-dir /app/env`` / ``PYTHONPATH``,
|
| 21 |
+
but works for any installed-wheel invocation without requiring a wrapper
|
| 22 |
+
script or Docker.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import os
|
| 28 |
+
import sys
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def main() -> None:
|
| 32 |
+
# __file__ resolves to site-packages/sql_drift_env/_cli.py after
|
| 33 |
+
# installation, so its parent IS the directory that contains models.py,
|
| 34 |
+
# server/, actors/, etc.
|
| 35 |
+
_pkg_dir = os.path.dirname(os.path.abspath(__file__))
|
| 36 |
+
if _pkg_dir not in sys.path:
|
| 37 |
+
sys.path.insert(0, _pkg_dir)
|
| 38 |
+
|
| 39 |
+
# Import lazily so the sys.path fix takes effect before any flat import
|
| 40 |
+
# in server/app.py or its transitive dependencies is attempted.
|
| 41 |
+
from server.app import main as _server_main
|
| 42 |
+
|
| 43 |
+
_server_main()
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
if __name__ == "__main__":
|
| 47 |
+
main()
|
actors/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Non-agent actors (Engineering Manager, DBA Oracle) — P4/P9."""
|
actors/dba_oracle.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DBA Oracle: per-scenario escalating expert guidance.
|
| 2 |
+
|
| 3 |
+
Feature-flagged (off by default). When ``enable_dba_oracle=True`` is set
|
| 4 |
+
at ``reset(...)`` or ``SQL_DRIFT_ENABLE_DBA_ORACLE=1`` is exported in the
|
| 5 |
+
environment, the ``consult_dba`` tool becomes available; three hint
|
| 6 |
+
tiers per scenario are shipped here, covering all 10 scenarios (6 static
|
| 7 |
+
+ 4 drift).
|
| 8 |
+
|
| 9 |
+
Hints escalate in specificity: tier 1 names the likely failure mode and
|
| 10 |
+
the diagnostic to confirm it, tier 2 gives the rewrite invariant, and
|
| 11 |
+
tier 3 is a near-spoiler SQL skeleton. The rubric penalizes each consult
|
| 12 |
+
(consultation rubric), so the agent only wins by consulting if the hint net-reduces
|
| 13 |
+
downstream steps.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
from typing import Final
|
| 20 |
+
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
# Per-scenario 3-tier hint tables
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
_HINTS: Final[dict[str, tuple[str, str, str]]] = {
|
| 27 |
+
"01_correlated_subquery": (
|
| 28 |
+
"[DBA tier 1] The expensive shape is a projection-time correlated subquery: "
|
| 29 |
+
"a COUNT over `orders` for every `users` row. Confirm by spotting "
|
| 30 |
+
"`SELECT COUNT(*) ... WHERE o.user_id = u.id` in the SELECT list or by "
|
| 31 |
+
"checking EXPLAIN for repeated dependent work. Preserve one output row per user.",
|
| 32 |
+
"[DBA tier 2] Aggregate fulfilled orders once by `user_id`, then LEFT JOIN that "
|
| 33 |
+
"small result to `users`. Keep the join outer and wrap the count with "
|
| 34 |
+
"`COALESCE(..., 0)` so users with no fulfilled orders stay in the result.",
|
| 35 |
+
"[DBA tier 3] Use `SELECT u.id, u.tier, COALESCE(c.n, 0) AS fulfilled_orders "
|
| 36 |
+
"FROM users u LEFT JOIN (SELECT user_id, COUNT(*) AS n FROM orders WHERE "
|
| 37 |
+
"status = 'fulfilled' GROUP BY user_id) c ON c.user_id = u.id ORDER BY u.id`. "
|
| 38 |
+
"Validate the row count equals the number of users.",
|
| 39 |
+
),
|
| 40 |
+
"02_select_star_join": (
|
| 41 |
+
"[DBA tier 1] The waste is over-projection: the inner three-way join uses "
|
| 42 |
+
"`SELECT *`, including wide product text and order metadata, while the outer "
|
| 43 |
+
"query keeps only `order_id`, product `name`, and `qty`.",
|
| 44 |
+
"[DBA tier 2] Inline the join and project exactly `oi.order_id`, `p.name`, and "
|
| 45 |
+
"`oi.qty`. Keep the products and orders joins plus the filters "
|
| 46 |
+
"`p.category = 'books'` and `oi.qty >= 2`; the wrapper exists only to hide "
|
| 47 |
+
"the star projection.",
|
| 48 |
+
"[DBA tier 3] Rewrite as `SELECT oi.order_id, p.name, oi.qty FROM order_items oi "
|
| 49 |
+
"JOIN products p ON p.id = oi.product_id JOIN orders o ON o.id = oi.order_id "
|
| 50 |
+
"WHERE p.category = 'books' AND oi.qty >= 2 ORDER BY oi.order_id, p.name`.",
|
| 51 |
+
),
|
| 52 |
+
"03_cartesian_join": (
|
| 53 |
+
"[DBA tier 1] This is an accidental cartesian join. `FROM events e, tenants t` "
|
| 54 |
+
"combined with `t.id = e.tenant_id + 0` prevents the optimizer from seeing a "
|
| 55 |
+
"clean tenant-key join early.",
|
| 56 |
+
"[DBA tier 2] Turn the comma join into an explicit equijoin on the tenant key. "
|
| 57 |
+
"Move only `t.id = e.tenant_id` into `ON`; keep the severity filter in `WHERE` "
|
| 58 |
+
"and preserve grouping by tenant tier.",
|
| 59 |
+
"[DBA tier 3] Use `SELECT t.tier, COUNT(*) AS n FROM events e JOIN tenants t "
|
| 60 |
+
"ON t.id = e.tenant_id WHERE e.severity IN ('error', 'critical') GROUP BY "
|
| 61 |
+
"t.tier ORDER BY t.tier`. Avoid arithmetic on the join key.",
|
| 62 |
+
),
|
| 63 |
+
"04_distinct_groupby": (
|
| 64 |
+
"[DBA tier 1] The duplicate-removal work is redundant. `GROUP BY session_id, path` "
|
| 65 |
+
"already emits one row per `(session_id, path)` pair, so a leading `DISTINCT` "
|
| 66 |
+
"adds a second deduplication pass over grouped rows.",
|
| 67 |
+
"[DBA tier 2] Do not introduce a CTE or change the aggregation grain. Remove only "
|
| 68 |
+
"`DISTINCT`; keep `COUNT(*) AS hits`, the same GROUP BY keys, and the same "
|
| 69 |
+
"ordering so row identity and sort order stay stable.",
|
| 70 |
+
"[DBA tier 3] The target shape is `SELECT session_id, path, COUNT(*) AS hits "
|
| 71 |
+
"FROM pageviews GROUP BY session_id, path ORDER BY session_id, path`. Validate "
|
| 72 |
+
"against the baseline result before comparing runtime.",
|
| 73 |
+
),
|
| 74 |
+
"05_nested_subquery": (
|
| 75 |
+
"[DBA tier 1] The nested `IN` clauses express a semi-join: authors who wrote "
|
| 76 |
+
"comments on published articles. The important identity is `comments.author_id`, "
|
| 77 |
+
"not `articles.author_id`.",
|
| 78 |
+
"[DBA tier 2] Flatten to `authors -> comments -> articles`, filter "
|
| 79 |
+
"`articles.status = 'published'`, and select distinct author display names. "
|
| 80 |
+
"`DISTINCT` is required here because one author can have many qualifying comments.",
|
| 81 |
+
"[DBA tier 3] Use `SELECT DISTINCT a.display_name FROM authors a JOIN comments c "
|
| 82 |
+
"ON c.author_id = a.id JOIN articles ar ON ar.id = c.article_id WHERE "
|
| 83 |
+
"ar.status = 'published' ORDER BY a.display_name`.",
|
| 84 |
+
),
|
| 85 |
+
"06_having_as_where": (
|
| 86 |
+
"[DBA tier 1] `status = 'fulfilled'` is a row-level predicate sitting in HAVING, "
|
| 87 |
+
"so the engine groups every status first and discards most groups afterward. "
|
| 88 |
+
"Only `SUM(amount_cents) >= 100000` truly belongs after aggregation.",
|
| 89 |
+
"[DBA tier 2] Move the status filter into `WHERE` before the GROUP BY. Keep "
|
| 90 |
+
"`status` in the projection and grouping to preserve the result shape, then "
|
| 91 |
+
"leave the aggregate threshold in HAVING.",
|
| 92 |
+
"[DBA tier 3] Use `SELECT tenant_id, status, SUM(amount_cents) AS total_cents "
|
| 93 |
+
"FROM orders WHERE status = 'fulfilled' GROUP BY tenant_id, status HAVING "
|
| 94 |
+
"SUM(amount_cents) >= 100000 ORDER BY tenant_id`.",
|
| 95 |
+
),
|
| 96 |
+
"07_drift_column_rename": (
|
| 97 |
+
"[DBA tier 1] If the old aggregation now fails with an unknown `user_id`, this is "
|
| 98 |
+
"schema drift rather than a performance issue. Read the changelog or describe "
|
| 99 |
+
"`orders`; `users.id` is unchanged, but the order-owner column moved.",
|
| 100 |
+
"[DBA tier 2] Replace every reference to `orders.user_id` with `orders.account_id` "
|
| 101 |
+
"in SELECT, GROUP BY, JOIN, and ORDER BY positions. Do not change the aggregate "
|
| 102 |
+
"logic; the rename preserves row semantics.",
|
| 103 |
+
"[DBA tier 3] Submit `SELECT account_id, COUNT(*) AS n_orders, ROUND(SUM(amount), 2) "
|
| 104 |
+
"AS total FROM orders GROUP BY account_id ORDER BY account_id`. Validate that "
|
| 105 |
+
"counts and totals match the pre-drift business result.",
|
| 106 |
+
),
|
| 107 |
+
"08_drift_date_format": (
|
| 108 |
+
"[DBA tier 1] The `events.ts` identifier still exists, but its type changed from "
|
| 109 |
+
"ISO text to BIGINT epoch milliseconds. A string date predicate can parse or "
|
| 110 |
+
"compare incorrectly; confirm with `describe_table('events')` and samples.",
|
| 111 |
+
"[DBA tier 2] Keep the same half-open UTC day window, but express both bounds as "
|
| 112 |
+
"epoch-ms integers. For 2026-04-21T00:00:00Z through the next midnight, use "
|
| 113 |
+
"`1776729600000 <= ts < 1776816000000`.",
|
| 114 |
+
"[DBA tier 3] Use `SELECT kind, COUNT(*) AS n FROM events WHERE ts >= "
|
| 115 |
+
"1776729600000 AND ts < 1776816000000 GROUP BY kind ORDER BY kind`. Do not quote "
|
| 116 |
+
"the bounds; they must be numeric comparisons against the BIGINT column.",
|
| 117 |
+
),
|
| 118 |
+
"09_drift_enum_rule": (
|
| 119 |
+
"[DBA tier 1] A formerly valid equality on `status = 'active'` now silently loses "
|
| 120 |
+
"rows because the business state was split into multiple stored labels. Sample "
|
| 121 |
+
"`tenants.status` before assuming the old lowercase value still exists.",
|
| 122 |
+
'[DBA tier 2] Preserve the business meaning "active tenants" by filtering on the '
|
| 123 |
+
"union of replacement labels. Keep the same grouping by tier and ordering; only "
|
| 124 |
+
"the status predicate changes.",
|
| 125 |
+
"[DBA tier 3] Use `SELECT tier, COUNT(*) AS n FROM tenants WHERE status IN "
|
| 126 |
+
"('ACTIVE', 'ACTIVE_V2') GROUP BY tier ORDER BY tier`. Avoid `LOWER(status) = "
|
| 127 |
+
"'active'`; it misses `ACTIVE_V2`.",
|
| 128 |
+
),
|
| 129 |
+
"10_drift_field_deprecation": (
|
| 130 |
+
"[DBA tier 1] The inline `posts.author_name` column was normalized away. Describe "
|
| 131 |
+
"`posts` and list tables: you should see `posts.users_id` plus a new `users` "
|
| 132 |
+
"lookup carrying the human-readable name.",
|
| 133 |
+
"[DBA tier 2] Join `posts` to `users` through the new FK, group by `u.full_name`, "
|
| 134 |
+
"and alias it back to `author_name` so the result keeps the old report shape. "
|
| 135 |
+
"The post count still comes from `posts`.",
|
| 136 |
+
"[DBA tier 3] Use `SELECT u.full_name AS author_name, COUNT(*) AS n_posts FROM "
|
| 137 |
+
"posts p JOIN users u ON u.id = p.users_id GROUP BY u.full_name ORDER BY "
|
| 138 |
+
"u.full_name`.",
|
| 139 |
+
),
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# ---------------------------------------------------------------------------
|
| 144 |
+
# Public API
|
| 145 |
+
# ---------------------------------------------------------------------------
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def is_enabled(
|
| 149 |
+
reset_flag: bool | None = None, *, env_var: str = "SQL_DRIFT_ENABLE_DBA_ORACLE"
|
| 150 |
+
) -> bool:
|
| 151 |
+
"""Resolve the feature flag from (reset kwarg, env var, default-off)."""
|
| 152 |
+
if reset_flag is not None:
|
| 153 |
+
return bool(reset_flag)
|
| 154 |
+
raw = os.environ.get(env_var, "").strip().lower()
|
| 155 |
+
return raw in ("1", "true", "yes", "on")
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def get_hint(scenario_id: str, tier: int) -> str:
|
| 159 |
+
"""Return the hint for ``(scenario_id, tier)``; clamps tier to [1, 3].
|
| 160 |
+
|
| 161 |
+
Raises :class:`KeyError` on unknown scenario so tests can detect when
|
| 162 |
+
a new scenario was added without a hint table entry.
|
| 163 |
+
"""
|
| 164 |
+
if scenario_id not in _HINTS:
|
| 165 |
+
raise KeyError(f"no DBA hints for scenario_id={scenario_id!r}; known: {sorted(_HINTS)}")
|
| 166 |
+
tier = max(1, min(3, int(tier)))
|
| 167 |
+
return _HINTS[scenario_id][tier - 1]
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def has_hints(scenario_id: str) -> bool:
|
| 171 |
+
return scenario_id in _HINTS
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def known_scenarios() -> frozenset[str]:
|
| 175 |
+
return frozenset(_HINTS)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
__all__ = [
|
| 179 |
+
"get_hint",
|
| 180 |
+
"has_hints",
|
| 181 |
+
"is_enabled",
|
| 182 |
+
"known_scenarios",
|
| 183 |
+
]
|
actors/engineering_manager.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Engineering-manager narrator — authors deterministic drift runbooks.
|
| 2 |
+
|
| 3 |
+
Consumed by the ``read_changelog`` tool. Output is deterministic per
|
| 4 |
+
``DriftConfig`` so tests can assert it character-for-character and the
|
| 5 |
+
agent can learn to parse drift kinds from the text.
|
| 6 |
+
|
| 7 |
+
Entries are concise migration notes: what changed, why the old query
|
| 8 |
+
breaks, how to adapt, and what to validate before submission.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from typing import TYPE_CHECKING
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
from scenarios.base import DriftConfig
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _sentence(prefix: str, body: str) -> str:
|
| 20 |
+
return f"[changelog] {prefix} {body}".strip()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def author_changelog(drift_config: DriftConfig) -> str:
|
| 24 |
+
kind = drift_config.kind
|
| 25 |
+
p = drift_config.payload
|
| 26 |
+
if kind == "column_rename":
|
| 27 |
+
return _sentence(
|
| 28 |
+
"Schema change -",
|
| 29 |
+
(
|
| 30 |
+
f"`{p['table']}.{p['old']}` was renamed to `{p['table']}.{p['new']}` "
|
| 31 |
+
"as part of a naming normalization. Impact: queries that SELECT, JOIN, "
|
| 32 |
+
"GROUP BY, ORDER BY, or filter on the old identifier now fail at bind "
|
| 33 |
+
"time. Migration: use the new column everywhere the old one appeared; "
|
| 34 |
+
"related table primary keys are unchanged. Validate that counts and "
|
| 35 |
+
"aggregates are unchanged after the rename."
|
| 36 |
+
),
|
| 37 |
+
)
|
| 38 |
+
if kind == "date_format":
|
| 39 |
+
return _sentence(
|
| 40 |
+
"Format change -",
|
| 41 |
+
(
|
| 42 |
+
f"`{p['table']}.{p['col']}` no longer stores ISO-8601 strings; it now "
|
| 43 |
+
"stores BIGINT milliseconds since the Unix epoch. Impact: quoted "
|
| 44 |
+
"timestamp literals and text comparisons no longer express the right "
|
| 45 |
+
"predicate. Migration: keep the same half-open time window, but compare "
|
| 46 |
+
"against numeric epoch-ms bounds. Validate with describe_table and a "
|
| 47 |
+
"small sample before submitting."
|
| 48 |
+
),
|
| 49 |
+
)
|
| 50 |
+
if kind == "enum_rule":
|
| 51 |
+
new_values = ", ".join(f"`{v}`" for v in p["new_values"])
|
| 52 |
+
return _sentence(
|
| 53 |
+
"Business-rule change -",
|
| 54 |
+
(
|
| 55 |
+
f"`{p['table']}.{p['col']}` value `{p['old_value']}` has been split "
|
| 56 |
+
f"into {new_values}. Impact: equality predicates on the old value "
|
| 57 |
+
"silently undercount after the deploy. Migration: replace the single "
|
| 58 |
+
"value predicate with an IN predicate over every replacement label. "
|
| 59 |
+
"Validate by sampling the enum distribution and preserving the existing "
|
| 60 |
+
"GROUP BY and projection."
|
| 61 |
+
),
|
| 62 |
+
)
|
| 63 |
+
if kind == "field_deprecation":
|
| 64 |
+
orig_t, orig_c = p["orig"]
|
| 65 |
+
lt, lid, lname = p["lookup"]
|
| 66 |
+
fk_col = f"{lt}_{lid}"
|
| 67 |
+
return _sentence(
|
| 68 |
+
"Deprecation -",
|
| 69 |
+
(
|
| 70 |
+
f"`{orig_t}.{orig_c}` was deprecated and replaced by `{orig_t}.{fk_col}` "
|
| 71 |
+
f"pointing at `{lt}.{lid}`; the display value now lives on `{lt}.{lname}`. "
|
| 72 |
+
"Impact: projections or groups over the old inline string fail after "
|
| 73 |
+
"drift. Migration: join through the lookup table, project the display "
|
| 74 |
+
"column, and keep the old output alias if callers expect it. Validate "
|
| 75 |
+
"row counts after the join."
|
| 76 |
+
),
|
| 77 |
+
)
|
| 78 |
+
raise ValueError(f"unknown drift kind={kind!r}")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
__all__ = ["author_changelog"]
|
client.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SQLDrift ``EnvClient`` — tool-aware payload constructors + response parser.
|
| 2 |
+
|
| 3 |
+
Inherits :class:`openenv.core.env_client.EnvClient` so TRL rollouts,
|
| 4 |
+
notebook exploration, and integration tests all use the same WS-backed
|
| 5 |
+
session semantics. Stateful episodes MUST go through the ``/ws`` channel
|
| 6 |
+
(HTTP ``/step`` is stateless: one fresh env per request).
|
| 7 |
+
|
| 8 |
+
Convenience constructors (:meth:`SqlDriftEnv.action_list_tables`, etc.)
|
| 9 |
+
hide the discriminated-union boilerplate so agent code reads naturally::
|
| 10 |
+
|
| 11 |
+
env = SqlDriftEnv(base_url="http://localhost:8000").sync()
|
| 12 |
+
with env:
|
| 13 |
+
r = env.reset(seed=42, scenario_id="03_cartesian_join")
|
| 14 |
+
r = env.step(SqlDriftEnv.action_run_query("SELECT COUNT(*) FROM events"))
|
| 15 |
+
...
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
from typing import Any
|
| 21 |
+
|
| 22 |
+
from openenv.core.client_types import StepResult
|
| 23 |
+
from openenv.core.env_client import EnvClient
|
| 24 |
+
|
| 25 |
+
from models import (
|
| 26 |
+
ConsultDBAPayload,
|
| 27 |
+
DescribeTablePayload,
|
| 28 |
+
ExplainQueryPayload,
|
| 29 |
+
ListTablesPayload,
|
| 30 |
+
ReadChangelogPayload,
|
| 31 |
+
RunQueryPayload,
|
| 32 |
+
SampleRowsPayload,
|
| 33 |
+
SqlDriftAction,
|
| 34 |
+
SqlDriftObservation,
|
| 35 |
+
SqlDriftState,
|
| 36 |
+
SubmitRewritePayload,
|
| 37 |
+
ToolName,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class SqlDriftEnv(EnvClient[SqlDriftAction, SqlDriftObservation, SqlDriftState]):
|
| 42 |
+
"""Tool-aware client for the SQLDrift OpenEnv environment."""
|
| 43 |
+
|
| 44 |
+
# ------------------------------------------------------------------
|
| 45 |
+
# EnvClient ABC implementations
|
| 46 |
+
# ------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
def _step_payload(self, action: SqlDriftAction) -> dict[str, Any]:
|
| 49 |
+
return action.model_dump(mode="json")
|
| 50 |
+
|
| 51 |
+
def _parse_result(self, payload: dict[str, Any]) -> StepResult[SqlDriftObservation]:
|
| 52 |
+
obs_data = payload.get("observation", {})
|
| 53 |
+
observation = SqlDriftObservation.model_validate(obs_data)
|
| 54 |
+
# Base transport strips reward + done off the observation dict — we
|
| 55 |
+
# re-populate them so the agent can read straight off `.observation`.
|
| 56 |
+
reward = payload.get("reward")
|
| 57 |
+
done = bool(payload.get("done", False))
|
| 58 |
+
observation.reward = reward
|
| 59 |
+
observation.done = done
|
| 60 |
+
return StepResult(observation=observation, reward=reward, done=done)
|
| 61 |
+
|
| 62 |
+
def _parse_state(self, payload: dict[str, Any]) -> SqlDriftState:
|
| 63 |
+
return SqlDriftState.model_validate(payload)
|
| 64 |
+
|
| 65 |
+
# ------------------------------------------------------------------
|
| 66 |
+
# Action factories — one per tool, accepting only the args that tool
|
| 67 |
+
# cares about; payload.kind is filled in automatically.
|
| 68 |
+
# ------------------------------------------------------------------
|
| 69 |
+
|
| 70 |
+
@staticmethod
|
| 71 |
+
def action_list_tables() -> SqlDriftAction:
|
| 72 |
+
return SqlDriftAction(tool=ToolName.LIST_TABLES, payload=ListTablesPayload())
|
| 73 |
+
|
| 74 |
+
@staticmethod
|
| 75 |
+
def action_describe_table(table: str) -> SqlDriftAction:
|
| 76 |
+
return SqlDriftAction(
|
| 77 |
+
tool=ToolName.DESCRIBE_TABLE,
|
| 78 |
+
payload=DescribeTablePayload(table=table),
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
@staticmethod
|
| 82 |
+
def action_sample_rows(table: str, limit: int = 5) -> SqlDriftAction:
|
| 83 |
+
return SqlDriftAction(
|
| 84 |
+
tool=ToolName.SAMPLE_ROWS,
|
| 85 |
+
payload=SampleRowsPayload(table=table, limit=limit),
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
@staticmethod
|
| 89 |
+
def action_run_query(sql: str) -> SqlDriftAction:
|
| 90 |
+
return SqlDriftAction(
|
| 91 |
+
tool=ToolName.RUN_QUERY,
|
| 92 |
+
payload=RunQueryPayload(sql=sql),
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
@staticmethod
|
| 96 |
+
def action_explain_query(sql: str) -> SqlDriftAction:
|
| 97 |
+
return SqlDriftAction(
|
| 98 |
+
tool=ToolName.EXPLAIN_QUERY,
|
| 99 |
+
payload=ExplainQueryPayload(sql=sql),
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
@staticmethod
|
| 103 |
+
def action_read_changelog() -> SqlDriftAction:
|
| 104 |
+
return SqlDriftAction(tool=ToolName.READ_CHANGELOG, payload=ReadChangelogPayload())
|
| 105 |
+
|
| 106 |
+
@staticmethod
|
| 107 |
+
def action_submit_rewrite(sql: str) -> SqlDriftAction:
|
| 108 |
+
return SqlDriftAction(
|
| 109 |
+
tool=ToolName.SUBMIT_REWRITE,
|
| 110 |
+
payload=SubmitRewritePayload(sql=sql),
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
@staticmethod
|
| 114 |
+
def action_consult_dba(question: str) -> SqlDriftAction:
|
| 115 |
+
return SqlDriftAction(
|
| 116 |
+
tool=ToolName.CONSULT_DBA,
|
| 117 |
+
payload=ConsultDBAPayload(question=question),
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
__all__ = ["SqlDriftEnv"]
|
engine/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Deterministic verifier / profiler / drift / reward engine."""
|
engine/drift.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Drift engine: four atomic, idempotent DDL operations.
|
| 2 |
+
|
| 3 |
+
Each apply_* function mutates ``conn`` in place inside a DuckDB
|
| 4 |
+
``BEGIN; ... COMMIT`` pair and returns a machine-readable changelog
|
| 5 |
+
string. Humans consume the string via the :class:`read_changelog`
|
| 6 |
+
tool; the rubric consults a separate drift-acknowledgement flag on the
|
| 7 |
+
runtime state, not the string itself.
|
| 8 |
+
|
| 9 |
+
Idempotency is enforced via a post-condition schema probe: once the drift
|
| 10 |
+
has been applied (the target column / enum value is in the expected
|
| 11 |
+
post-state), a second call short-circuits with the same changelog string.
|
| 12 |
+
This matters because the environment's drift-trigger check runs every
|
| 13 |
+
step and needs to be safe to retry.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from typing import TYPE_CHECKING, Any
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
import duckdb
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# DuckDB auto-commits DDL and forbids mixing multi-statement transactions
|
| 25 |
+
# with schema alterations across commit boundaries. Each drift operation
|
| 26 |
+
# therefore executes its statements sequentially on the default
|
| 27 |
+
# auto-commit connection; individual DML statements (UPDATEs) are
|
| 28 |
+
# internally atomic at the statement level, which is sufficient for the
|
| 29 |
+
# fixture mutation the env needs. If a drift operation raises mid-way we
|
| 30 |
+
# tear down and re-seed the DuckDB via ScenarioSpec.materialize — there's
|
| 31 |
+
# no long-lived on-disk state to roll back.
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _table_columns(conn: duckdb.DuckDBPyConnection, table: str) -> list[str]:
|
| 35 |
+
rows = conn.execute(f"PRAGMA table_info('{table}')").fetchall()
|
| 36 |
+
# PRAGMA table_info returns (cid, name, type, notnull, dflt_value, pk)
|
| 37 |
+
return [r[1] for r in rows]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _table_exists(conn: duckdb.DuckDBPyConnection, table: str) -> bool:
|
| 41 |
+
rows = conn.execute(
|
| 42 |
+
"SELECT COUNT(*) FROM information_schema.tables WHERE table_name = ?", [table]
|
| 43 |
+
).fetchone()
|
| 44 |
+
return bool(rows and rows[0])
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# =============================================================================
|
| 48 |
+
# Column rename
|
| 49 |
+
# =============================================================================
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def apply_column_rename(conn: duckdb.DuckDBPyConnection, payload: dict[str, Any]) -> str:
|
| 53 |
+
"""``{"table": str, "old": str, "new": str}``."""
|
| 54 |
+
table = payload["table"]
|
| 55 |
+
old = payload["old"]
|
| 56 |
+
new = payload["new"]
|
| 57 |
+
|
| 58 |
+
cols = _table_columns(conn, table)
|
| 59 |
+
if new in cols and old not in cols:
|
| 60 |
+
return f"rename_already_applied:{table}.{old}->{new}"
|
| 61 |
+
if old not in cols:
|
| 62 |
+
raise ValueError(f"column_rename: {table}.{old} missing (cols={cols})")
|
| 63 |
+
|
| 64 |
+
conn.execute(f'ALTER TABLE "{table}" RENAME COLUMN "{old}" TO "{new}"')
|
| 65 |
+
return f"rename:{table}.{old}->{new}"
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# =============================================================================
|
| 69 |
+
# Date format change (iso_string → epoch_ms)
|
| 70 |
+
# =============================================================================
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def apply_date_format_change(conn: duckdb.DuckDBPyConnection, payload: dict[str, Any]) -> str:
|
| 74 |
+
"""``{"table": str, "col": str, "from": "iso_string", "to": "epoch_ms"}``.
|
| 75 |
+
|
| 76 |
+
Only the one direction is supported for now; the payload still carries
|
| 77 |
+
from/to for forward-compatibility and audit.
|
| 78 |
+
"""
|
| 79 |
+
table = payload["table"]
|
| 80 |
+
col = payload["col"]
|
| 81 |
+
from_fmt = payload.get("from", "iso_string")
|
| 82 |
+
to_fmt = payload.get("to", "epoch_ms")
|
| 83 |
+
if (from_fmt, to_fmt) != ("iso_string", "epoch_ms"):
|
| 84 |
+
raise NotImplementedError(
|
| 85 |
+
f"date_format_change only supports iso_string→epoch_ms, got {from_fmt}→{to_fmt}"
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
cols = _table_columns(conn, table)
|
| 89 |
+
# Idempotent: once column is BIGINT, consider it applied.
|
| 90 |
+
type_row = conn.execute(
|
| 91 |
+
"SELECT data_type FROM information_schema.columns WHERE table_name = ? AND column_name = ?",
|
| 92 |
+
[table, col],
|
| 93 |
+
).fetchone()
|
| 94 |
+
if type_row is None:
|
| 95 |
+
raise ValueError(f"date_format_change: {table}.{col} missing (cols={cols})")
|
| 96 |
+
if "BIGINT" in type_row[0].upper() or "INT" in type_row[0].upper():
|
| 97 |
+
return f"date_format_already_applied:{table}.{col}"
|
| 98 |
+
|
| 99 |
+
tmp = f"{col}_epoch_ms"
|
| 100 |
+
conn.execute(f'ALTER TABLE "{table}" ADD COLUMN "{tmp}" BIGINT')
|
| 101 |
+
conn.execute(
|
| 102 |
+
f'UPDATE "{table}" SET "{tmp}" = '
|
| 103 |
+
f'CAST(EXTRACT(EPOCH FROM CAST("{col}" AS TIMESTAMP)) * 1000 AS BIGINT)'
|
| 104 |
+
)
|
| 105 |
+
conn.execute(f'ALTER TABLE "{table}" DROP COLUMN "{col}"')
|
| 106 |
+
conn.execute(f'ALTER TABLE "{table}" RENAME COLUMN "{tmp}" TO "{col}"')
|
| 107 |
+
|
| 108 |
+
return f"date_format:{table}.{col}:iso_string->epoch_ms"
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# =============================================================================
|
| 112 |
+
# Enum rule change (split `old_value` into N new values)
|
| 113 |
+
# =============================================================================
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def apply_enum_rule_change(conn: duckdb.DuckDBPyConnection, payload: dict[str, Any]) -> str:
|
| 117 |
+
"""``{"table": str, "col": str, "old_value": str, "new_values": list[str]}``.
|
| 118 |
+
|
| 119 |
+
Rows holding ``old_value`` are re-distributed deterministically into
|
| 120 |
+
``new_values`` (round-robin by rowid) so the split is reproducible.
|
| 121 |
+
"""
|
| 122 |
+
table = payload["table"]
|
| 123 |
+
col = payload["col"]
|
| 124 |
+
old_value = payload["old_value"]
|
| 125 |
+
new_values: list[str] = list(payload["new_values"])
|
| 126 |
+
if not new_values:
|
| 127 |
+
raise ValueError("enum_rule_change: new_values must be non-empty")
|
| 128 |
+
|
| 129 |
+
count_row = conn.execute(
|
| 130 |
+
f'SELECT COUNT(*) FROM "{table}" WHERE "{col}" = ?', [old_value]
|
| 131 |
+
).fetchone()
|
| 132 |
+
count_old = count_row[0] if count_row is not None else 0
|
| 133 |
+
# Idempotent: if old_value has already been drained AND any of the
|
| 134 |
+
# new_values is present, treat as applied.
|
| 135 |
+
if count_old == 0:
|
| 136 |
+
has_new_row = conn.execute(
|
| 137 |
+
f'SELECT COUNT(*) FROM "{table}" WHERE "{col}" IN ({",".join("?" * len(new_values))})',
|
| 138 |
+
new_values,
|
| 139 |
+
).fetchone()
|
| 140 |
+
has_new = has_new_row[0] if has_new_row is not None else 0
|
| 141 |
+
if has_new > 0:
|
| 142 |
+
return f"enum_rule_already_applied:{table}.{col}:{old_value}->{new_values}"
|
| 143 |
+
|
| 144 |
+
# Deterministic split by rowid mod N.
|
| 145 |
+
case_branches = " ".join(
|
| 146 |
+
f"WHEN mod(rid, {len(new_values)}) = {i} THEN '{v}'" for i, v in enumerate(new_values)
|
| 147 |
+
)
|
| 148 |
+
conn.execute(
|
| 149 |
+
f"CREATE TEMP TABLE _enum_remap AS "
|
| 150 |
+
f"SELECT rowid AS rid, "
|
| 151 |
+
f"CASE {case_branches} END AS new_val "
|
| 152 |
+
f'FROM "{table}" WHERE "{col}" = ?',
|
| 153 |
+
[old_value],
|
| 154 |
+
)
|
| 155 |
+
conn.execute(
|
| 156 |
+
f'UPDATE "{table}" SET "{col}" = _enum_remap.new_val '
|
| 157 |
+
f'FROM _enum_remap WHERE _enum_remap.rid = "{table}".rowid'
|
| 158 |
+
)
|
| 159 |
+
conn.execute("DROP TABLE _enum_remap")
|
| 160 |
+
|
| 161 |
+
return f"enum_rule:{table}.{col}:{old_value}->{'+'.join(new_values)}"
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# =============================================================================
|
| 165 |
+
# Field deprecation (replace inline string col with FK lookup)
|
| 166 |
+
# =============================================================================
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def apply_field_deprecation(conn: duckdb.DuckDBPyConnection, payload: dict[str, Any]) -> str:
|
| 170 |
+
"""``{"orig": (table, col), "lookup": (table, id_col, name_col)}``.
|
| 171 |
+
|
| 172 |
+
- Creates the lookup table (if missing) and seeds it with distinct values
|
| 173 |
+
observed on ``orig.col``.
|
| 174 |
+
- Adds ``orig.<lookup_id>`` with a FK-style backfill.
|
| 175 |
+
- Drops ``orig.col``.
|
| 176 |
+
"""
|
| 177 |
+
orig_table, orig_col = payload["orig"]
|
| 178 |
+
lookup_table, lookup_id_col, lookup_name_col = payload["lookup"]
|
| 179 |
+
new_fk_col = f"{lookup_table}_{lookup_id_col}" # e.g. "users_id"
|
| 180 |
+
|
| 181 |
+
orig_cols = _table_columns(conn, orig_table)
|
| 182 |
+
if orig_col not in orig_cols and new_fk_col in orig_cols:
|
| 183 |
+
return f"field_deprecation_already_applied:{orig_table}.{orig_col}"
|
| 184 |
+
if orig_col not in orig_cols:
|
| 185 |
+
raise ValueError(f"field_deprecation: {orig_table}.{orig_col} missing (cols={orig_cols})")
|
| 186 |
+
|
| 187 |
+
if not _table_exists(conn, lookup_table):
|
| 188 |
+
conn.execute(
|
| 189 |
+
f'CREATE TABLE "{lookup_table}" ('
|
| 190 |
+
f' "{lookup_id_col}" BIGINT PRIMARY KEY,'
|
| 191 |
+
f' "{lookup_name_col}" VARCHAR'
|
| 192 |
+
");"
|
| 193 |
+
)
|
| 194 |
+
conn.execute(
|
| 195 |
+
f'INSERT INTO "{lookup_table}" ("{lookup_id_col}", "{lookup_name_col}") '
|
| 196 |
+
f"SELECT ROW_NUMBER() OVER (ORDER BY v) + "
|
| 197 |
+
f'COALESCE((SELECT MAX("{lookup_id_col}") FROM "{lookup_table}"), 0), v '
|
| 198 |
+
f'FROM (SELECT DISTINCT "{orig_col}" AS v FROM "{orig_table}") '
|
| 199 |
+
f"WHERE v IS NOT NULL "
|
| 200 |
+
f' AND v NOT IN (SELECT "{lookup_name_col}" FROM "{lookup_table}");'
|
| 201 |
+
)
|
| 202 |
+
conn.execute(f'ALTER TABLE "{orig_table}" ADD COLUMN "{new_fk_col}" BIGINT')
|
| 203 |
+
conn.execute(
|
| 204 |
+
f'UPDATE "{orig_table}" SET "{new_fk_col}" = lookup."{lookup_id_col}" '
|
| 205 |
+
f'FROM "{lookup_table}" lookup '
|
| 206 |
+
f'WHERE lookup."{lookup_name_col}" = "{orig_table}"."{orig_col}"'
|
| 207 |
+
)
|
| 208 |
+
conn.execute(f'ALTER TABLE "{orig_table}" DROP COLUMN "{orig_col}"')
|
| 209 |
+
|
| 210 |
+
return (
|
| 211 |
+
f"field_deprecation:{orig_table}.{orig_col}->"
|
| 212 |
+
f"{orig_table}.{new_fk_col}→{lookup_table}.{lookup_name_col}"
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# =============================================================================
|
| 217 |
+
# Dispatcher
|
| 218 |
+
# =============================================================================
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
DRIFT_HANDLERS = {
|
| 222 |
+
"column_rename": apply_column_rename,
|
| 223 |
+
"date_format": apply_date_format_change,
|
| 224 |
+
"enum_rule": apply_enum_rule_change,
|
| 225 |
+
"field_deprecation": apply_field_deprecation,
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def apply_drift(conn: duckdb.DuckDBPyConnection, kind: str, payload: dict[str, Any]) -> str:
|
| 230 |
+
if kind not in DRIFT_HANDLERS:
|
| 231 |
+
raise ValueError(f"unknown drift kind={kind!r}; known: {sorted(DRIFT_HANDLERS)}")
|
| 232 |
+
return DRIFT_HANDLERS[kind](conn, payload)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
__all__ = [
|
| 236 |
+
"DRIFT_HANDLERS",
|
| 237 |
+
"apply_column_rename",
|
| 238 |
+
"apply_date_format_change",
|
| 239 |
+
"apply_drift",
|
| 240 |
+
"apply_enum_rule_change",
|
| 241 |
+
"apply_field_deprecation",
|
| 242 |
+
]
|
engine/profiler.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Query profiling utilities.
|
| 2 |
+
|
| 3 |
+
A watchdog-wrapped DuckDB execute plus a median-of-3 warm timer.
|
| 4 |
+
|
| 5 |
+
* :func:`execute_once_timed` runs a statement exactly once, enforcing a
|
| 6 |
+
hard ``timeout_s`` wall-clock budget. It is the single entry point used
|
| 7 |
+
by the env for agent-provided SQL so the documented query timeout
|
| 8 |
+
cannot be bypassed. An optional ``max_rows`` caps result-set
|
| 9 |
+
materialization — the fetch is aborted as soon as more than
|
| 10 |
+
``max_rows`` rows are observed, so a pathological ``SELECT *`` cannot
|
| 11 |
+
drive the server OOM before the caller's size check runs.
|
| 12 |
+
* :func:`execute_hash_timed` executes a statement once and hashes its full
|
| 13 |
+
result incrementally via ``fetchmany`` so correctness checks do not have
|
| 14 |
+
to materialize the full row set in Python memory.
|
| 15 |
+
* :func:`median_of_3_warm_ms` performs one untimed warm-up then three
|
| 16 |
+
timed runs and returns the median milliseconds. Used by scenario
|
| 17 |
+
materialization to publish a stable baseline runtime.
|
| 18 |
+
|
| 19 |
+
Both helpers raise :class:`TimeoutError` when a single run exceeds the
|
| 20 |
+
budget; ``duckdb.Error`` propagates unchanged to the caller.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import contextlib
|
| 26 |
+
import os
|
| 27 |
+
import threading
|
| 28 |
+
import time
|
| 29 |
+
from collections.abc import Callable, Iterator
|
| 30 |
+
from dataclasses import dataclass
|
| 31 |
+
from typing import TYPE_CHECKING, Any, cast
|
| 32 |
+
|
| 33 |
+
from engine.verifier import canonical_row_hash
|
| 34 |
+
from utilities.logger import get_module_logger
|
| 35 |
+
|
| 36 |
+
if TYPE_CHECKING:
|
| 37 |
+
import duckdb
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
DEFAULT_TIMEOUT_S: float = 2.0
|
| 41 |
+
INTERRUPT_GRACE_S: float = 0.25
|
| 42 |
+
# Maximum number of watchdog escalations (leaked threads) tolerated before
|
| 43 |
+
# logging at CRITICAL. Override via SQL_DRIFT_MAX_LEAKED_WATCHDOG_THREADS.
|
| 44 |
+
MAX_LEAKED_WATCHDOG_THREADS: int = int(os.environ.get("SQL_DRIFT_MAX_LEAKED_WATCHDOG_THREADS", "3"))
|
| 45 |
+
|
| 46 |
+
_LOG = get_module_logger(__name__)
|
| 47 |
+
_FETCH_CHUNK_ROWS = 1024
|
| 48 |
+
|
| 49 |
+
# Module-level counter — incremented each time a watchdog thread survives
|
| 50 |
+
# interrupt (i.e. a genuine escalation, not a normal timeout). Thread-safe
|
| 51 |
+
# via _watchdog_leak_lock. Callers can read this via get_watchdog_leak_count().
|
| 52 |
+
_watchdog_leak_lock: threading.Lock = threading.Lock()
|
| 53 |
+
_watchdog_leaked_count: int = 0
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_watchdog_leak_count() -> int:
|
| 57 |
+
"""Return the cumulative number of watchdog threads that survived interrupt.
|
| 58 |
+
|
| 59 |
+
A non-zero value means at least one DuckDB worker thread was not stopped
|
| 60 |
+
cleanly and is still alive in the background. Production monitoring should
|
| 61 |
+
alert when this exceeds :data:`MAX_LEAKED_WATCHDOG_THREADS`.
|
| 62 |
+
"""
|
| 63 |
+
return _watchdog_leaked_count
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class QueryWatchdogEscalationError(RuntimeError):
|
| 67 |
+
"""DuckDB worker survived interrupt; the connection is no longer safe."""
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@dataclass(frozen=True)
|
| 71 |
+
class TimedResult:
|
| 72 |
+
"""Output of :func:`execute_once_timed`.
|
| 73 |
+
|
| 74 |
+
``columns`` preserves DuckDB's cursor ``description`` order so callers
|
| 75 |
+
can emit a :class:`models.RunQueryResult` without re-executing the
|
| 76 |
+
query just to recover column names.
|
| 77 |
+
|
| 78 |
+
``truncated`` is ``True`` when the caller supplied a ``max_rows`` cap
|
| 79 |
+
and the query produced strictly more rows than that cap; in that
|
| 80 |
+
case ``rows`` contains exactly ``max_rows + 1`` entries (the
|
| 81 |
+
one-over read that proves overflow). Callers that care about size
|
| 82 |
+
limits must branch on ``truncated`` rather than re-checking
|
| 83 |
+
``len(rows)`` against their cap.
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
columns: list[str]
|
| 87 |
+
rows: list[tuple[Any, ...]]
|
| 88 |
+
elapsed_ms: float
|
| 89 |
+
truncated: bool = False
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _fetch_capped(
|
| 93 |
+
cursor: duckdb.DuckDBPyConnection,
|
| 94 |
+
max_rows: int,
|
| 95 |
+
) -> tuple[list[tuple[Any, ...]], bool]:
|
| 96 |
+
"""Drain at most ``max_rows + 1`` rows from ``cursor`` via fetchmany.
|
| 97 |
+
|
| 98 |
+
Returns ``(rows, truncated)``. When ``truncated`` is ``True`` the
|
| 99 |
+
cursor still has unread rows — we stopped on the first row past the
|
| 100 |
+
cap so the caller can signal overflow without materialising the
|
| 101 |
+
rest of a potentially enormous result set.
|
| 102 |
+
"""
|
| 103 |
+
# chunk=1024 trades a few extra Python calls for not over-fetching
|
| 104 |
+
# by orders of magnitude when results are modest. The +1 in the
|
| 105 |
+
# final budget is what makes overflow detectable.
|
| 106 |
+
rows: list[tuple[Any, ...]] = []
|
| 107 |
+
budget = max_rows + 1
|
| 108 |
+
while budget > 0:
|
| 109 |
+
batch = cursor.fetchmany(min(_FETCH_CHUNK_ROWS, budget))
|
| 110 |
+
if not batch:
|
| 111 |
+
return rows, False
|
| 112 |
+
rows.extend(batch)
|
| 113 |
+
budget -= len(batch)
|
| 114 |
+
return rows, len(rows) > max_rows
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _iter_cursor_rows(
|
| 118 |
+
cursor: duckdb.DuckDBPyConnection,
|
| 119 |
+
) -> Iterator[tuple[Any, ...]]:
|
| 120 |
+
while True:
|
| 121 |
+
batch = cursor.fetchmany(_FETCH_CHUNK_ROWS)
|
| 122 |
+
if not batch:
|
| 123 |
+
return
|
| 124 |
+
yield from batch
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _run_worker_with_watchdog[T](
|
| 128 |
+
conn: duckdb.DuckDBPyConnection,
|
| 129 |
+
sql: str,
|
| 130 |
+
timeout_s: float,
|
| 131 |
+
worker: Callable[[], T],
|
| 132 |
+
) -> T:
|
| 133 |
+
result_holder: dict[str, object] = {}
|
| 134 |
+
|
| 135 |
+
def runner() -> None:
|
| 136 |
+
try:
|
| 137 |
+
result_holder["result"] = worker()
|
| 138 |
+
except BaseException as exc: # Must forward all failures from the worker thread.
|
| 139 |
+
result_holder["error"] = exc
|
| 140 |
+
|
| 141 |
+
thread = threading.Thread(target=runner, daemon=True)
|
| 142 |
+
thread.start()
|
| 143 |
+
thread.join(timeout_s)
|
| 144 |
+
if thread.is_alive():
|
| 145 |
+
# DuckDB's interrupt API is connection-scoped and thread-safe;
|
| 146 |
+
# we ask the query to unwind and then wait *unconditionally*
|
| 147 |
+
# for the worker to exit before surfacing the timeout to the
|
| 148 |
+
# caller. If we returned while the thread were still alive, it
|
| 149 |
+
# would retain access to ``conn`` and its result could race
|
| 150 |
+
# future queries on the same connection — a previously
|
| 151 |
+
# observed source of flaky post-timeout behaviour. In practice
|
| 152 |
+
# DuckDB's interrupt releases the worker within a handful of
|
| 153 |
+
# milliseconds; if the engine ever fails to honour interrupt
|
| 154 |
+
# the process will hang here, which is the correct failure
|
| 155 |
+
# mode for a connection whose state is no longer safe to
|
| 156 |
+
# reuse.
|
| 157 |
+
with contextlib.suppress(Exception):
|
| 158 |
+
conn.interrupt()
|
| 159 |
+
thread.join(INTERRUPT_GRACE_S)
|
| 160 |
+
if thread.is_alive():
|
| 161 |
+
global _watchdog_leaked_count
|
| 162 |
+
with _watchdog_leak_lock:
|
| 163 |
+
_watchdog_leaked_count += 1
|
| 164 |
+
leak_count = _watchdog_leaked_count
|
| 165 |
+
log_fn = _LOG.critical if leak_count > MAX_LEAKED_WATCHDOG_THREADS else _LOG.error
|
| 166 |
+
log_fn(
|
| 167 |
+
"query watchdog failed to stop worker after %.3fs timeout + %.3fs grace"
|
| 168 |
+
" (cumulative leaked threads: %d)",
|
| 169 |
+
timeout_s,
|
| 170 |
+
INTERRUPT_GRACE_S,
|
| 171 |
+
leak_count,
|
| 172 |
+
)
|
| 173 |
+
raise QueryWatchdogEscalationError(
|
| 174 |
+
f"query exceeded {timeout_s}s and worker did not stop after interrupt: {sql[:120]!r}"
|
| 175 |
+
)
|
| 176 |
+
raise TimeoutError(f"query exceeded {timeout_s}s: {sql[:120]!r}")
|
| 177 |
+
if "error" in result_holder:
|
| 178 |
+
error = result_holder["error"]
|
| 179 |
+
assert isinstance(error, BaseException)
|
| 180 |
+
raise error
|
| 181 |
+
return cast(T, result_holder["result"])
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _run_with_watchdog(
|
| 185 |
+
conn: duckdb.DuckDBPyConnection,
|
| 186 |
+
sql: str,
|
| 187 |
+
timeout_s: float,
|
| 188 |
+
max_rows: int | None,
|
| 189 |
+
) -> TimedResult:
|
| 190 |
+
def worker() -> TimedResult:
|
| 191 |
+
start = time.perf_counter_ns()
|
| 192 |
+
cursor = conn.execute(sql)
|
| 193 |
+
columns = [d[0] for d in cursor.description] if cursor.description else []
|
| 194 |
+
if max_rows is None:
|
| 195 |
+
rows = cursor.fetchall()
|
| 196 |
+
truncated = False
|
| 197 |
+
else:
|
| 198 |
+
rows, truncated = _fetch_capped(cursor, max_rows)
|
| 199 |
+
elapsed_ns = time.perf_counter_ns() - start
|
| 200 |
+
return TimedResult(
|
| 201 |
+
columns=columns,
|
| 202 |
+
rows=rows,
|
| 203 |
+
elapsed_ms=elapsed_ns / 1_000_000.0,
|
| 204 |
+
truncated=truncated,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
result = _run_worker_with_watchdog(conn, sql, timeout_s, worker)
|
| 208 |
+
assert isinstance(result, TimedResult)
|
| 209 |
+
return result
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def execute_once_timed(
|
| 213 |
+
conn: duckdb.DuckDBPyConnection,
|
| 214 |
+
sql: str,
|
| 215 |
+
*,
|
| 216 |
+
timeout_s: float = DEFAULT_TIMEOUT_S,
|
| 217 |
+
max_rows: int | None = None,
|
| 218 |
+
) -> tuple[list[tuple[Any, ...]], float]:
|
| 219 |
+
"""Single timed execution — returns ``(rows, elapsed_ms)``.
|
| 220 |
+
|
| 221 |
+
Thin wrapper for callers that don't need column metadata or the
|
| 222 |
+
truncation flag.
|
| 223 |
+
"""
|
| 224 |
+
res = _run_with_watchdog(conn, sql, timeout_s, max_rows)
|
| 225 |
+
return res.rows, res.elapsed_ms
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def execute_once_with_columns(
|
| 229 |
+
conn: duckdb.DuckDBPyConnection,
|
| 230 |
+
sql: str,
|
| 231 |
+
*,
|
| 232 |
+
timeout_s: float = DEFAULT_TIMEOUT_S,
|
| 233 |
+
max_rows: int | None = None,
|
| 234 |
+
) -> TimedResult:
|
| 235 |
+
"""Single timed execution — returns columns + rows + elapsed_ms.
|
| 236 |
+
|
| 237 |
+
When ``max_rows`` is supplied, the fetch aborts at the first row
|
| 238 |
+
past the cap and ``TimedResult.truncated`` is set. The elapsed
|
| 239 |
+
milliseconds in that case reflect the partial scan, not the query's
|
| 240 |
+
would-be completion time — a truncated read is a *hard error* in
|
| 241 |
+
agent-facing code paths, not a performance measurement.
|
| 242 |
+
"""
|
| 243 |
+
return _run_with_watchdog(conn, sql, timeout_s, max_rows)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def execute_hash_timed(
|
| 247 |
+
conn: duckdb.DuckDBPyConnection,
|
| 248 |
+
sql: str,
|
| 249 |
+
*,
|
| 250 |
+
timeout_s: float = DEFAULT_TIMEOUT_S,
|
| 251 |
+
) -> tuple[str, float]:
|
| 252 |
+
"""Single timed execution — returns ``(result_hash, elapsed_ms)``.
|
| 253 |
+
|
| 254 |
+
Unlike :func:`execute_once_timed`, this drains the cursor via
|
| 255 |
+
``fetchmany`` and hashes rows incrementally, so callers can compare a
|
| 256 |
+
large final result to ground truth without materializing the full row
|
| 257 |
+
set in Python memory.
|
| 258 |
+
"""
|
| 259 |
+
|
| 260 |
+
def worker() -> tuple[str, float]:
|
| 261 |
+
start = time.perf_counter_ns()
|
| 262 |
+
cursor = conn.execute(sql)
|
| 263 |
+
result_hash = canonical_row_hash(_iter_cursor_rows(cursor))
|
| 264 |
+
elapsed_ns = time.perf_counter_ns() - start
|
| 265 |
+
return result_hash, elapsed_ns / 1_000_000.0
|
| 266 |
+
|
| 267 |
+
result = _run_worker_with_watchdog(conn, sql, timeout_s, worker)
|
| 268 |
+
result_hash, elapsed_ms = result
|
| 269 |
+
assert isinstance(result_hash, str)
|
| 270 |
+
assert isinstance(elapsed_ms, float)
|
| 271 |
+
return result_hash, elapsed_ms
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def median_of_3_warm_ms(
|
| 275 |
+
conn: duckdb.DuckDBPyConnection,
|
| 276 |
+
sql: str,
|
| 277 |
+
*,
|
| 278 |
+
timeout_s: float = DEFAULT_TIMEOUT_S,
|
| 279 |
+
) -> float:
|
| 280 |
+
"""Warm cache, then median-of-3 timed runs. Returns milliseconds."""
|
| 281 |
+
_run_with_watchdog(conn, sql, timeout_s, None)
|
| 282 |
+
timings = [_run_with_watchdog(conn, sql, timeout_s, None).elapsed_ms for _ in range(3)]
|
| 283 |
+
timings.sort()
|
| 284 |
+
return timings[1]
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
__all__ = [
|
| 288 |
+
"DEFAULT_TIMEOUT_S",
|
| 289 |
+
"INTERRUPT_GRACE_S",
|
| 290 |
+
"MAX_LEAKED_WATCHDOG_THREADS",
|
| 291 |
+
"QueryWatchdogEscalationError",
|
| 292 |
+
"TimedResult",
|
| 293 |
+
"execute_hash_timed",
|
| 294 |
+
"execute_once_timed",
|
| 295 |
+
"execute_once_with_columns",
|
| 296 |
+
"get_watchdog_leak_count",
|
| 297 |
+
"median_of_3_warm_ms",
|
| 298 |
+
]
|
engine/reward.py
ADDED
|
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SQLDrift composed rubric (correctness, drift, speedup, gates, DBA tax).
|
| 2 |
+
|
| 3 |
+
Six child rubrics, one per reward component (:data:`REWARD_COMPONENT_KEYS`):
|
| 4 |
+
|
| 5 |
+
r_correct correctness vs ground-truth hash, gated on ≥ 1.2× speedup
|
| 6 |
+
r_drift bonus/penalty for (not) adapting to post-drift identifiers
|
| 7 |
+
r_speedup tanh-shaped speedup bonus, gated on r_correct > 0
|
| 8 |
+
r_step_tax base step tax plus bounded productive-action rebates
|
| 9 |
+
r_gatekeepers escalating tool-error / repeat-failing / no-op penalties
|
| 10 |
+
r_consult_dba DBA-oracle consult penalties (feature-flagged; 0 when off)
|
| 11 |
+
|
| 12 |
+
All child rubrics share a single ``ctx_provider`` that returns the private
|
| 13 |
+
:class:`engine.runtime.RuntimeEpisodeState`; this keeps the rubric
|
| 14 |
+
stateless relative to the environment and makes each component
|
| 15 |
+
individually unit-testable with a synthesized triple
|
| 16 |
+
``(RuntimeEpisodeState, SqlDriftAction, SqlDriftObservation)``.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import math
|
| 22 |
+
import re
|
| 23 |
+
from collections.abc import Callable
|
| 24 |
+
from typing import TYPE_CHECKING
|
| 25 |
+
|
| 26 |
+
from openenv.core.rubrics import Rubric
|
| 27 |
+
|
| 28 |
+
from models import (
|
| 29 |
+
SqlDriftAction,
|
| 30 |
+
SqlDriftObservation,
|
| 31 |
+
SubmitRewriteResult,
|
| 32 |
+
ToolError,
|
| 33 |
+
ToolName,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
if TYPE_CHECKING:
|
| 37 |
+
from engine.runtime import RuntimeEpisodeState
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Tunable thresholds pulled out to module level so tests and future
|
| 41 |
+
# curriculum code share a single source of truth.
|
| 42 |
+
SPEEDUP_MIN: float = 1.2
|
| 43 |
+
SPEEDUP_CAP_FOR_INFTY: float = 64.0
|
| 44 |
+
STEP_TAX: float = -0.03
|
| 45 |
+
STEP_REBATE_LIST_TABLES: float = 0.04
|
| 46 |
+
STEP_REBATE_DESCRIBE_TABLE: float = 0.06
|
| 47 |
+
STEP_REBATE_SAMPLE_ROWS: float = 0.05
|
| 48 |
+
STEP_REBATE_RUN_QUERY: float = 0.04
|
| 49 |
+
STEP_REBATE_EXPLAIN_QUERY: float = 0.04
|
| 50 |
+
STEP_REBATE_READ_CHANGELOG: float = 0.08
|
| 51 |
+
|
| 52 |
+
GATE_MALFORMED_TOOL_CALL: float = -0.3
|
| 53 |
+
GATE_CONSECUTIVE_TOOL_ERROR: float = -0.1
|
| 54 |
+
GATE_REPEAT_FAILING_QUERY: float = -0.1
|
| 55 |
+
GATE_BASELINE_VERBATIM: float = -0.2
|
| 56 |
+
_MAX_ESCALATION_STEPS: int = 3
|
| 57 |
+
|
| 58 |
+
CONSULT_ESCALATION: tuple[float, float, float] = (-0.1, -0.3, -0.8)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# =============================================================================
|
| 62 |
+
# Helpers
|
| 63 |
+
# =============================================================================
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def canonicalize_sql(sql: str) -> str:
|
| 67 |
+
"""Whitespace/case/alias-insensitive canonical form.
|
| 68 |
+
|
| 69 |
+
Uses sqlglot's duckdb dialect round-trip so reorders/reformats agree;
|
| 70 |
+
falls back to a simple whitespace fold if sqlglot rejects the SQL
|
| 71 |
+
(e.g. during the baseline-verbatim check on an agent-submitted blob).
|
| 72 |
+
"""
|
| 73 |
+
try:
|
| 74 |
+
import sqlglot
|
| 75 |
+
|
| 76 |
+
expr = sqlglot.parse_one(sql, dialect="duckdb")
|
| 77 |
+
return expr.sql(dialect="duckdb", comments=False, normalize=True).strip().lower()
|
| 78 |
+
except Exception:
|
| 79 |
+
return " ".join(sql.lower().split())
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
_IDENT_RE = re.compile(r"\b([A-Za-z_][A-Za-z0-9_]*)\b")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _extract_identifiers(sql: str) -> frozenset[str]:
|
| 86 |
+
"""Loose case-preserving identifier harvest.
|
| 87 |
+
|
| 88 |
+
Strips string literals before tokenizing so e.g. `WHERE x = 'ACTIVE'`
|
| 89 |
+
does not leak 'ACTIVE' into the identifier set.
|
| 90 |
+
"""
|
| 91 |
+
stripped = re.sub(r"'[^']*'", "", sql)
|
| 92 |
+
return frozenset(m.group(0) for m in _IDENT_RE.finditer(stripped))
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _extract_column_references(sql: str) -> frozenset[str]:
|
| 96 |
+
"""Column-reference identifiers via sqlglot AST (excludes alias labels).
|
| 97 |
+
|
| 98 |
+
For drift-adapt scoring, ``SELECT account_id AS user_id`` references
|
| 99 |
+
the new column and merely labels the output — the alias must not
|
| 100 |
+
count as a surviving pre-drift marker. Falls back to the regex
|
| 101 |
+
extractor on parse failure so malformed SQL still scores something.
|
| 102 |
+
"""
|
| 103 |
+
try:
|
| 104 |
+
import sqlglot
|
| 105 |
+
expr = sqlglot.parse_one(sql, dialect="duckdb")
|
| 106 |
+
except Exception:
|
| 107 |
+
return _extract_identifiers(sql)
|
| 108 |
+
if expr is None:
|
| 109 |
+
return _extract_identifiers(sql)
|
| 110 |
+
return frozenset(
|
| 111 |
+
n.name for n in expr.walk() if isinstance(n, sqlglot.exp.Column) and n.name
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _literals(sql: str) -> frozenset[str]:
|
| 116 |
+
"""All `'..'`-quoted string literals in `sql`."""
|
| 117 |
+
return frozenset(re.findall(r"'([^']*)'", sql))
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
_AGENT_MS_EPSILON: float = 1e-6
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def effective_speedup(rt: RuntimeEpisodeState) -> float | None:
|
| 124 |
+
"""Compute ``effective_speedup`` from the runtime snapshot (speedup rubric).
|
| 125 |
+
|
| 126 |
+
Single source of truth for the speedup number used across the code
|
| 127 |
+
base — rubric scoring, the skill library's ``avg_speedup`` field,
|
| 128 |
+
and the training evaluator all route through here so divergent
|
| 129 |
+
definitions cannot drift apart.
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
|
| 133 |
+
* ``None`` — no submission has happened yet. Callers that need a
|
| 134 |
+
numeric default (e.g. the rubric, which is only invoked
|
| 135 |
+
post-submission) should verify ``rt.submitted`` first.
|
| 136 |
+
* ``+∞`` — drift has fired and the pre-drift baseline SQL no longer
|
| 137 |
+
executes against the post-drift schema; any correct submission is
|
| 138 |
+
definitionally "infinitely faster" than an unrunnable baseline.
|
| 139 |
+
* ``baseline_ms / max(agent_ms, ε)`` otherwise. A tiny ``ε`` clamp
|
| 140 |
+
guards against zero/negative timings from sub-microsecond queries
|
| 141 |
+
and treats them as "as fast as possible" (very large, finite
|
| 142 |
+
speedup) rather than silently collapsing the reward.
|
| 143 |
+
"""
|
| 144 |
+
if rt.submitted_runtime_ms is None:
|
| 145 |
+
return None
|
| 146 |
+
if rt.drift_fired and rt.baseline_postdrift_raises:
|
| 147 |
+
return math.inf
|
| 148 |
+
agent_ms = max(rt.submitted_runtime_ms, _AGENT_MS_EPSILON)
|
| 149 |
+
return rt.baseline_runtime_ms / agent_ms
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _speedup_for_reward(rt: RuntimeEpisodeState) -> float:
|
| 153 |
+
"""Rubric-facing speedup that never returns ``None``.
|
| 154 |
+
|
| 155 |
+
The rubric is only invoked once ``rt.submitted`` is True, so
|
| 156 |
+
:func:`effective_speedup` cannot return ``None`` from these call
|
| 157 |
+
sites; we assert that and coerce to ``0.0`` defensively if it ever
|
| 158 |
+
does (prevents a silent ``TypeError`` inside the reward math).
|
| 159 |
+
"""
|
| 160 |
+
val = effective_speedup(rt)
|
| 161 |
+
return 0.0 if val is None else val
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _is_terminal_submission(
|
| 165 |
+
action: SqlDriftAction,
|
| 166 |
+
observation: SqlDriftObservation,
|
| 167 |
+
rt: RuntimeEpisodeState,
|
| 168 |
+
) -> bool:
|
| 169 |
+
"""True iff this step is the submission step.
|
| 170 |
+
|
| 171 |
+
The env sets ``done=True`` on a successful submission and attaches a
|
| 172 |
+
:class:`SubmitRewriteResult`; we gate terminal rewards on both
|
| 173 |
+
signals so repeated rubric calls on an unchanged state don't
|
| 174 |
+
double-score.
|
| 175 |
+
"""
|
| 176 |
+
if not rt.submitted:
|
| 177 |
+
return False
|
| 178 |
+
if action.tool != ToolName.SUBMIT_REWRITE:
|
| 179 |
+
return False
|
| 180 |
+
tr = observation.tool_result
|
| 181 |
+
return isinstance(tr, SubmitRewriteResult)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _gt_hash(rt: RuntimeEpisodeState) -> str | None:
|
| 185 |
+
if rt.drift_fired and rt.gt_result_hash_postdrift is not None:
|
| 186 |
+
return rt.gt_result_hash_postdrift
|
| 187 |
+
return rt.gt_result_hash_predrift
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# =============================================================================
|
| 191 |
+
# Child rubrics
|
| 192 |
+
# =============================================================================
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class _CtxChild(Rubric):
|
| 196 |
+
"""Base child rubric sharing the ctx provider."""
|
| 197 |
+
|
| 198 |
+
def __init__(self, ctx_provider: Callable[[], RuntimeEpisodeState]) -> None:
|
| 199 |
+
super().__init__()
|
| 200 |
+
object.__setattr__(self, "_ctx", ctx_provider)
|
| 201 |
+
|
| 202 |
+
def forward(
|
| 203 |
+
self,
|
| 204 |
+
action: SqlDriftAction,
|
| 205 |
+
observation: SqlDriftObservation,
|
| 206 |
+
) -> float:
|
| 207 |
+
raise NotImplementedError
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class Correctness(_CtxChild):
|
| 211 |
+
"""Terminal-only correctness: +1.0 / +0.5 / -1.0 by hash and speedup."""
|
| 212 |
+
|
| 213 |
+
def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
|
| 214 |
+
rt = self._ctx()
|
| 215 |
+
if not _is_terminal_submission(action, observation, rt):
|
| 216 |
+
return 0.0
|
| 217 |
+
gt = _gt_hash(rt)
|
| 218 |
+
agent_hash = rt.submitted_result_hash
|
| 219 |
+
if gt is None or agent_hash is None:
|
| 220 |
+
return 0.0
|
| 221 |
+
if agent_hash != gt:
|
| 222 |
+
return -1.0
|
| 223 |
+
speedup = _speedup_for_reward(rt)
|
| 224 |
+
if speedup >= SPEEDUP_MIN:
|
| 225 |
+
return 1.0
|
| 226 |
+
return 0.5
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class DriftAdapt(_CtxChild):
|
| 230 |
+
"""+0.5 for a correctly-adapted submission, -0.5 for a pre-drift-only
|
| 231 |
+
submission after drift fired.
|
| 232 |
+
|
| 233 |
+
Adaptation is detected against two scenario-declared identifier sets:
|
| 234 |
+
|
| 235 |
+
* ``postdrift_identifiers`` — identifiers/literals that only a
|
| 236 |
+
correct post-drift rewrite will introduce (e.g. ``account_id``
|
| 237 |
+
after a column rename, ``'ACTIVE'`` after an enum split).
|
| 238 |
+
* ``predrift_identifiers`` — identifiers/literals a submission that
|
| 239 |
+
ignored the drift would retain (e.g. ``user_id``, ``'active'``,
|
| 240 |
+
the ISO anchor strings under date-format drift).
|
| 241 |
+
|
| 242 |
+
A submission is considered "adapted" when it either surfaces a
|
| 243 |
+
post-drift marker *or* the scenario declares no distinctive
|
| 244 |
+
post-drift identifiers (e.g. date-format drift keeps the same
|
| 245 |
+
column name and only the literal shape changes) AND it does not
|
| 246 |
+
retain any pre-drift marker. The penalty fires only when the
|
| 247 |
+
submission still carries pre-drift markers AND produced the wrong
|
| 248 |
+
post-drift result — so a merely partial rewrite (neither pre-
|
| 249 |
+
nor post-flavoured) never earns a penalty it can't diagnose.
|
| 250 |
+
"""
|
| 251 |
+
|
| 252 |
+
def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
|
| 253 |
+
rt = self._ctx()
|
| 254 |
+
# Only drift scenarios participate.
|
| 255 |
+
if rt.gt_result_hash_postdrift is None and not rt.drift_fired:
|
| 256 |
+
return 0.0
|
| 257 |
+
if not _is_terminal_submission(action, observation, rt):
|
| 258 |
+
return 0.0
|
| 259 |
+
inst = getattr(rt, "instance", None)
|
| 260 |
+
post_ids: frozenset[str] = (
|
| 261 |
+
getattr(inst, "postdrift_identifiers", frozenset()) or frozenset()
|
| 262 |
+
)
|
| 263 |
+
pre_ids: frozenset[str] = getattr(inst, "predrift_identifiers", frozenset()) or frozenset()
|
| 264 |
+
agent_sql = rt.submitted_sql or ""
|
| 265 |
+
|
| 266 |
+
idents = _extract_column_references(agent_sql)
|
| 267 |
+
literals = _literals(agent_sql)
|
| 268 |
+
markers = idents | literals
|
| 269 |
+
uses_post = bool(post_ids & markers)
|
| 270 |
+
uses_pre = bool(pre_ids & markers)
|
| 271 |
+
|
| 272 |
+
# Treat "no distinctive post identifier" scenarios as
|
| 273 |
+
# satisfied by absence-of-pre (see class docstring).
|
| 274 |
+
adapted = (uses_post or not post_ids) and not uses_pre
|
| 275 |
+
|
| 276 |
+
agent_hash = rt.submitted_result_hash
|
| 277 |
+
gt_post = rt.gt_result_hash_postdrift
|
| 278 |
+
|
| 279 |
+
if rt.drift_fired and agent_hash == gt_post and adapted:
|
| 280 |
+
return 0.5
|
| 281 |
+
if rt.drift_fired and uses_pre and agent_hash != gt_post:
|
| 282 |
+
return -0.5
|
| 283 |
+
return 0.0
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class Speedup(_CtxChild):
|
| 287 |
+
"""Terminal-only, gated on r_correct > 0: 0.3·tanh(log2(speedup)/3)."""
|
| 288 |
+
|
| 289 |
+
def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
|
| 290 |
+
rt = self._ctx()
|
| 291 |
+
if not _is_terminal_submission(action, observation, rt):
|
| 292 |
+
return 0.0
|
| 293 |
+
gt = _gt_hash(rt)
|
| 294 |
+
if gt is None or rt.submitted_result_hash != gt:
|
| 295 |
+
return 0.0
|
| 296 |
+
raw = _speedup_for_reward(rt)
|
| 297 |
+
if math.isinf(raw):
|
| 298 |
+
raw = SPEEDUP_CAP_FOR_INFTY
|
| 299 |
+
if raw <= 1.0:
|
| 300 |
+
return 0.0
|
| 301 |
+
return 0.3 * math.tanh(math.log2(raw) / 3.0)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class StepTax(_CtxChild):
|
| 305 |
+
"""Base step tax plus bounded rebates for productive exploration."""
|
| 306 |
+
|
| 307 |
+
def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
|
| 308 |
+
rt = self._ctx()
|
| 309 |
+
if _is_terminal_submission(action, observation, rt):
|
| 310 |
+
return 0.0
|
| 311 |
+
rebate = max(0.0, float(getattr(rt, "last_step_productive_rebate", 0.0)))
|
| 312 |
+
return STEP_TAX + rebate
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
class Gatekeepers(_CtxChild):
|
| 316 |
+
"""Sum of three independent penalties; repeats escalate up to a cap."""
|
| 317 |
+
|
| 318 |
+
def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
|
| 319 |
+
rt = self._ctx()
|
| 320 |
+
penalty = 0.0
|
| 321 |
+
# 1. Malformed / failed tool call — ToolError emitted this step.
|
| 322 |
+
if isinstance(observation.tool_result, ToolError):
|
| 323 |
+
penalty += GATE_MALFORMED_TOOL_CALL
|
| 324 |
+
streak = max(0, int(getattr(rt, "consecutive_tool_errors", 0)) - 1)
|
| 325 |
+
penalty += GATE_CONSECUTIVE_TOOL_ERROR * min(streak, _MAX_ESCALATION_STEPS)
|
| 326 |
+
# 2. Repeat failing query — env marks the flag on the runtime
|
| 327 |
+
# state immediately before invoking the rubric.
|
| 328 |
+
repeats = max(0, int(getattr(rt, "last_step_repeat_failing_query_count", 0)) - 1)
|
| 329 |
+
if repeats > 0:
|
| 330 |
+
penalty += GATE_REPEAT_FAILING_QUERY * min(repeats, _MAX_ESCALATION_STEPS)
|
| 331 |
+
# 3. Baseline-verbatim submission (Rev-3 gate — stacks with
|
| 332 |
+
# correctness's +0.5 partial to cap the no-op rewrite at +0.3).
|
| 333 |
+
if (
|
| 334 |
+
action.tool == ToolName.SUBMIT_REWRITE
|
| 335 |
+
and _is_terminal_submission(action, observation, rt)
|
| 336 |
+
and rt.submitted_sql_canonical == rt.baseline_sql_canonical
|
| 337 |
+
):
|
| 338 |
+
penalty += GATE_BASELINE_VERBATIM
|
| 339 |
+
return penalty
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class ConsultDBA(_CtxChild):
|
| 343 |
+
"""Escalating penalties -0.1 / -0.3 / -0.8 per consult when the flag is on."""
|
| 344 |
+
|
| 345 |
+
def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
|
| 346 |
+
rt = self._ctx()
|
| 347 |
+
oracle_enabled = getattr(rt, "dba_oracle_enabled", False)
|
| 348 |
+
if not oracle_enabled:
|
| 349 |
+
return 0.0
|
| 350 |
+
if action.tool != ToolName.CONSULT_DBA:
|
| 351 |
+
return 0.0
|
| 352 |
+
# Count the consult THIS step by indexing into the escalation
|
| 353 |
+
# table using the pre-increment value (env increments on the same step).
|
| 354 |
+
tier = min(rt.consultations_used, len(CONSULT_ESCALATION))
|
| 355 |
+
if tier <= 0:
|
| 356 |
+
return CONSULT_ESCALATION[0]
|
| 357 |
+
return CONSULT_ESCALATION[tier - 1]
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
# =============================================================================
|
| 361 |
+
# Composite
|
| 362 |
+
# =============================================================================
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
class SqlDriftRubric(Rubric):
|
| 366 |
+
"""Composite rubric: sum of six children.
|
| 367 |
+
|
| 368 |
+
Registration as attributes auto-enrolls them in
|
| 369 |
+
:meth:`Rubric.named_rubrics` so training loops can introspect
|
| 370 |
+
per-component scores.
|
| 371 |
+
"""
|
| 372 |
+
|
| 373 |
+
def __init__(self, ctx_provider: Callable[[], RuntimeEpisodeState]) -> None:
|
| 374 |
+
super().__init__()
|
| 375 |
+
# NOTE: order matters — correctness must populate last_score before
|
| 376 |
+
# speedup reads it via the shared ctx_provider (both are pure
|
| 377 |
+
# functions of the runtime state, so identical output — but the
|
| 378 |
+
# explicit ordering documents the intent).
|
| 379 |
+
self.correctness = Correctness(ctx_provider)
|
| 380 |
+
self.drift_adapt = DriftAdapt(ctx_provider)
|
| 381 |
+
self.speedup = Speedup(ctx_provider)
|
| 382 |
+
self.step_tax = StepTax(ctx_provider)
|
| 383 |
+
self.gatekeepers = Gatekeepers(ctx_provider)
|
| 384 |
+
self.consult_dba = ConsultDBA(ctx_provider)
|
| 385 |
+
|
| 386 |
+
def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
|
| 387 |
+
total = (
|
| 388 |
+
self.correctness(action, observation)
|
| 389 |
+
+ self.drift_adapt(action, observation)
|
| 390 |
+
+ self.speedup(action, observation)
|
| 391 |
+
+ self.step_tax(action, observation)
|
| 392 |
+
+ self.gatekeepers(action, observation)
|
| 393 |
+
+ self.consult_dba(action, observation)
|
| 394 |
+
)
|
| 395 |
+
return total
|
| 396 |
+
|
| 397 |
+
def component_scores(self) -> dict[str, float]:
|
| 398 |
+
"""Return the most-recent per-component scores, keyed for W&B.
|
| 399 |
+
|
| 400 |
+
Keys match :data:`models.REWARD_COMPONENT_KEYS` so the observation
|
| 401 |
+
envelope and the demo plots agree on a stable schema.
|
| 402 |
+
"""
|
| 403 |
+
return {
|
| 404 |
+
"r_correct": float(self.correctness.last_score or 0.0),
|
| 405 |
+
"r_drift": float(self.drift_adapt.last_score or 0.0),
|
| 406 |
+
"r_speedup": float(self.speedup.last_score or 0.0),
|
| 407 |
+
"r_step_tax": float(self.step_tax.last_score or 0.0),
|
| 408 |
+
"r_gatekeepers": float(self.gatekeepers.last_score or 0.0),
|
| 409 |
+
"r_consult_dba": float(self.consult_dba.last_score or 0.0),
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
__all__ = [
|
| 414 |
+
"CONSULT_ESCALATION",
|
| 415 |
+
"ConsultDBA",
|
| 416 |
+
"Correctness",
|
| 417 |
+
"DriftAdapt",
|
| 418 |
+
"GATE_BASELINE_VERBATIM",
|
| 419 |
+
"GATE_CONSECUTIVE_TOOL_ERROR",
|
| 420 |
+
"GATE_MALFORMED_TOOL_CALL",
|
| 421 |
+
"GATE_REPEAT_FAILING_QUERY",
|
| 422 |
+
"Gatekeepers",
|
| 423 |
+
"SPEEDUP_CAP_FOR_INFTY",
|
| 424 |
+
"SPEEDUP_MIN",
|
| 425 |
+
"STEP_REBATE_DESCRIBE_TABLE",
|
| 426 |
+
"STEP_REBATE_EXPLAIN_QUERY",
|
| 427 |
+
"STEP_REBATE_LIST_TABLES",
|
| 428 |
+
"STEP_REBATE_READ_CHANGELOG",
|
| 429 |
+
"STEP_REBATE_RUN_QUERY",
|
| 430 |
+
"STEP_REBATE_SAMPLE_ROWS",
|
| 431 |
+
"STEP_TAX",
|
| 432 |
+
"Speedup",
|
| 433 |
+
"SqlDriftRubric",
|
| 434 |
+
"StepTax",
|
| 435 |
+
"canonicalize_sql",
|
| 436 |
+
"effective_speedup",
|
| 437 |
+
]
|
engine/runtime.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Private per-episode runtime state (not exposed over the wire).
|
| 2 |
+
|
| 3 |
+
This module is imported by both :mod:`engine.reward` and
|
| 4 |
+
:mod:`server.sql_drift_env_environment` — keeping it out of ``server/``
|
| 5 |
+
avoids the import cycle ``engine.reward → server → engine.reward``.
|
| 6 |
+
|
| 7 |
+
NEVER serialize or expose this over any endpoint. The public state
|
| 8 |
+
projection lives in :class:`models.SqlDriftState`.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
+
from typing import TYPE_CHECKING, Any
|
| 15 |
+
|
| 16 |
+
from models import EpisodePhase
|
| 17 |
+
|
| 18 |
+
if TYPE_CHECKING:
|
| 19 |
+
import duckdb
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class RuntimeEpisodeState:
|
| 24 |
+
"""Private per-episode state — holds secrets (DB handle, ground truth)."""
|
| 25 |
+
|
| 26 |
+
episode_id: str
|
| 27 |
+
seed: int
|
| 28 |
+
scenario_id: str
|
| 29 |
+
|
| 30 |
+
instance: Any # scenarios.base.ScenarioInstance — forward ref avoids import cycle
|
| 31 |
+
conn: duckdb.DuckDBPyConnection
|
| 32 |
+
|
| 33 |
+
# Ground truth hashes (never exposed)
|
| 34 |
+
gt_result_hash_predrift: str
|
| 35 |
+
gt_result_hash_postdrift: str | None
|
| 36 |
+
|
| 37 |
+
# Baseline measurements (used by speedup + baseline-verbatim gates)
|
| 38 |
+
baseline_runtime_ms: float
|
| 39 |
+
baseline_tokens: int
|
| 40 |
+
baseline_sql_canonical: str
|
| 41 |
+
baseline_postdrift_raises: bool
|
| 42 |
+
|
| 43 |
+
# Drift timing (scheduled step + cooldown relative to first run_query).
|
| 44 |
+
drift_scheduled_step: int | None
|
| 45 |
+
connection_poisoned: bool = False
|
| 46 |
+
drift_fired_step: int | None = None
|
| 47 |
+
first_run_query_step: int | None = None
|
| 48 |
+
|
| 49 |
+
# Episode progression
|
| 50 |
+
step_count: int = 0
|
| 51 |
+
phase: EpisodePhase = EpisodePhase.DIAGNOSE
|
| 52 |
+
budget_steps: int = 25
|
| 53 |
+
|
| 54 |
+
# Per-step bookkeeping
|
| 55 |
+
failed_query_hashes: set[str] = field(default_factory=set)
|
| 56 |
+
failed_query_counts: dict[str, int] = field(default_factory=dict)
|
| 57 |
+
changelog_entries: list[str] = field(default_factory=list)
|
| 58 |
+
consultations_used: int = 0
|
| 59 |
+
listed_tables_rewarded: bool = False
|
| 60 |
+
described_tables_rewarded: set[str] = field(default_factory=set)
|
| 61 |
+
sampled_tables_rewarded: set[str] = field(default_factory=set)
|
| 62 |
+
run_query_rewarded: bool = False
|
| 63 |
+
explain_query_rewarded: bool = False
|
| 64 |
+
changelog_rewarded_after_drift: bool = False
|
| 65 |
+
|
| 66 |
+
# Phase-machine bookkeeping — counts successful diagnostic tool calls
|
| 67 |
+
# (list_tables, describe_table, sample_rows, run_query, explain_query,
|
| 68 |
+
# read_changelog). The DIAGNOSE → REWRITE transition fires the first
|
| 69 |
+
# time this becomes non-zero; SUBMIT_REWRITE is rejected while this is
|
| 70 |
+
# still zero (ToolErrorCode.SUBMIT_BEFORE_DIAGNOSE).
|
| 71 |
+
diagnostic_actions_taken: int = 0
|
| 72 |
+
|
| 73 |
+
# Submission state — populated once SUBMIT_REWRITE is accepted
|
| 74 |
+
submitted: bool = False
|
| 75 |
+
submitted_sql: str | None = None
|
| 76 |
+
submitted_sql_canonical: str | None = None
|
| 77 |
+
submitted_result_hash: str | None = None
|
| 78 |
+
submitted_runtime_ms: float | None = None
|
| 79 |
+
|
| 80 |
+
# Last-step signal — consumed by the rubric to compute per-step penalties
|
| 81 |
+
last_step_was_tool_error: bool = False
|
| 82 |
+
last_step_was_repeat_failing_query: bool = False
|
| 83 |
+
last_step_repeat_failing_query_count: int = 0
|
| 84 |
+
last_step_productive_rebate: float = 0.0
|
| 85 |
+
consecutive_tool_errors: int = 0
|
| 86 |
+
|
| 87 |
+
# Drift acknowledgement — set True the first time the agent reads the
|
| 88 |
+
# changelog or observes post-drift schema identifiers in a query.
|
| 89 |
+
drift_acknowledged: bool = False
|
| 90 |
+
|
| 91 |
+
# DBA Oracle feature flag (read by the ConsultDBA child rubric).
|
| 92 |
+
# Always False unless explicitly enabled at reset (kwarg or env var).
|
| 93 |
+
dba_oracle_enabled: bool = False
|
| 94 |
+
|
| 95 |
+
@property
|
| 96 |
+
def drift_fired(self) -> bool:
|
| 97 |
+
return self.drift_fired_step is not None
|
| 98 |
+
|
| 99 |
+
@property
|
| 100 |
+
def budget_steps_remaining(self) -> int:
|
| 101 |
+
return max(0, self.budget_steps - self.step_count)
|
engine/verifier.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Deterministic result verification.
|
| 2 |
+
|
| 3 |
+
Canonicalizes floats to `FLOAT_PRECISION` decimal places and treats NULL
|
| 4 |
+
uniformly so that two result sets with the same semantic content hash to
|
| 5 |
+
the same digest regardless of row order, floating-point noise, or None vs
|
| 6 |
+
SQL NULL representation.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import hashlib
|
| 12 |
+
from collections.abc import Iterable
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
NULL_SENTINEL = "\x00NULL\x00"
|
| 16 |
+
FLOAT_PRECISION = 6
|
| 17 |
+
_DIGEST_MODULUS = 1 << 256
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _normalize_value(v: Any) -> Any:
|
| 21 |
+
if v is None:
|
| 22 |
+
return NULL_SENTINEL
|
| 23 |
+
if isinstance(v, float):
|
| 24 |
+
# NaN hashes to itself here — a NaN in rows yields a deterministic
|
| 25 |
+
# digest, but two NaNs round-trip to different representations when
|
| 26 |
+
# we `repr`. Guard explicitly.
|
| 27 |
+
if v != v: # NaN
|
| 28 |
+
return NULL_SENTINEL
|
| 29 |
+
return round(v, FLOAT_PRECISION)
|
| 30 |
+
return v
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _row_digest_int(row: Iterable[Any]) -> int:
|
| 34 |
+
normalized = tuple(_normalize_value(v) for v in row)
|
| 35 |
+
digest = hashlib.sha256(repr(normalized).encode()).digest()
|
| 36 |
+
return int.from_bytes(digest, "big", signed=False)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def canonical_row_hash(rows: Iterable[Iterable[Any]]) -> str:
|
| 40 |
+
"""Order-independent hash of a result set.
|
| 41 |
+
|
| 42 |
+
This stays order-independent and duplicate-sensitive without
|
| 43 |
+
materializing the full result in memory. Each normalized row is
|
| 44 |
+
hashed once, then folded into three commutative accumulators so the
|
| 45 |
+
final digest is stable across row order and Python processes.
|
| 46 |
+
"""
|
| 47 |
+
row_count = 0
|
| 48 |
+
sum_acc = 0
|
| 49 |
+
sumsq_acc = 0
|
| 50 |
+
xor_acc = 0
|
| 51 |
+
for row in rows:
|
| 52 |
+
row_count += 1
|
| 53 |
+
row_digest = _row_digest_int(row)
|
| 54 |
+
sum_acc = (sum_acc + row_digest) % _DIGEST_MODULUS
|
| 55 |
+
sumsq_acc = (sumsq_acc + ((row_digest * row_digest) % _DIGEST_MODULUS)) % _DIGEST_MODULUS
|
| 56 |
+
xor_acc ^= row_digest
|
| 57 |
+
payload = b"".join(
|
| 58 |
+
(
|
| 59 |
+
row_count.to_bytes(32, "big", signed=False),
|
| 60 |
+
sum_acc.to_bytes(32, "big", signed=False),
|
| 61 |
+
sumsq_acc.to_bytes(32, "big", signed=False),
|
| 62 |
+
xor_acc.to_bytes(32, "big", signed=False),
|
| 63 |
+
)
|
| 64 |
+
)
|
| 65 |
+
return hashlib.sha256(payload).hexdigest()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def result_matches(agent_rows: Iterable[Iterable[Any]], gt_hash: str) -> bool:
|
| 69 |
+
"""True if `agent_rows` canonicalizes to the ground-truth hash."""
|
| 70 |
+
return canonical_row_hash(agent_rows) == gt_hash
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def row_set_jaccard(a: Iterable[Iterable[Any]], b: Iterable[Iterable[Any]]) -> float:
|
| 74 |
+
"""Jaccard over normalized row sets (order- and duplicate-insensitive).
|
| 75 |
+
|
| 76 |
+
Each input row is normalised with :func:`_normalize_value` and
|
| 77 |
+
collapsed into a :class:`frozenset`-style Python ``set``, so rows
|
| 78 |
+
that repeat within a single result are counted once. This is
|
| 79 |
+
deliberately *not* a multiset Jaccard — multiset semantics would
|
| 80 |
+
punish correct queries that legitimately emit duplicates more
|
| 81 |
+
harshly than intended.
|
| 82 |
+
|
| 83 |
+
Not used by the lean reward today, but kept covered by tests so
|
| 84 |
+
we can opt in later without rework.
|
| 85 |
+
"""
|
| 86 |
+
norm_a = {tuple(_normalize_value(v) for v in row) for row in a}
|
| 87 |
+
norm_b = {tuple(_normalize_value(v) for v in row) for row in b}
|
| 88 |
+
if not norm_a and not norm_b:
|
| 89 |
+
return 1.0
|
| 90 |
+
union = norm_a | norm_b
|
| 91 |
+
inter = norm_a & norm_b
|
| 92 |
+
return len(inter) / len(union)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
__all__ = [
|
| 96 |
+
"FLOAT_PRECISION",
|
| 97 |
+
"NULL_SENTINEL",
|
| 98 |
+
"canonical_row_hash",
|
| 99 |
+
"result_matches",
|
| 100 |
+
"row_set_jaccard",
|
| 101 |
+
]
|
models.py
ADDED
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public data models for SQLDrift.
|
| 2 |
+
|
| 3 |
+
Rev 3 design notes enforced here:
|
| 4 |
+
|
| 5 |
+
- Action is a discriminated union over a public `kind: Literal[...]` tag on
|
| 6 |
+
each payload sub-model. Pydantic v2 forbids leading-underscore names as
|
| 7 |
+
discriminator keys (reserved for private attrs), so we keep the tag public.
|
| 8 |
+
- `SqlDriftAction` cross-validates that the envelope-level `tool` matches
|
| 9 |
+
`payload.kind` (prevents inconsistent envelopes from being constructed).
|
| 10 |
+
- `SqlDriftObservation.tool_result` is itself a discriminated union over the
|
| 11 |
+
eight concrete result types plus `ToolError` (for in-env semantic failures;
|
| 12 |
+
envelope-level `ValidationError` is a transport-layer concern, not an in-env code).
|
| 13 |
+
- `SqlDriftState` is the public state snapshot shipped over `/state`. It
|
| 14 |
+
never carries ground truth, DB handles, baseline runtime, or seeds;
|
| 15 |
+
`extra="forbid"` guarantees no accidental leak as new fields are added.
|
| 16 |
+
The private `RuntimeEpisodeState` lives in :mod:`engine.runtime`.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
from enum import StrEnum
|
| 22 |
+
from typing import Annotated, Any, Literal
|
| 23 |
+
|
| 24 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 25 |
+
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
| 26 |
+
from pydantic_core import PydanticCustomError
|
| 27 |
+
|
| 28 |
+
# =============================================================================
|
| 29 |
+
# Enums
|
| 30 |
+
# =============================================================================
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ToolName(StrEnum):
|
| 34 |
+
LIST_TABLES = "list_tables"
|
| 35 |
+
DESCRIBE_TABLE = "describe_table"
|
| 36 |
+
SAMPLE_ROWS = "sample_rows"
|
| 37 |
+
RUN_QUERY = "run_query"
|
| 38 |
+
EXPLAIN_QUERY = "explain_query"
|
| 39 |
+
READ_CHANGELOG = "read_changelog"
|
| 40 |
+
SUBMIT_REWRITE = "submit_rewrite"
|
| 41 |
+
CONSULT_DBA = "consult_dba"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class EpisodePhase(StrEnum):
|
| 45 |
+
DIAGNOSE = "diagnose"
|
| 46 |
+
REWRITE = "rewrite"
|
| 47 |
+
DRIFT_RECOVERY = "drift_recovery"
|
| 48 |
+
FINALIZE = "finalize"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class ToolErrorCode(StrEnum):
|
| 52 |
+
"""In-environment semantic failure codes (API contract).
|
| 53 |
+
|
| 54 |
+
Envelope-level `pydantic.ValidationError` is handled by the OpenEnv
|
| 55 |
+
transport layer (HTTP 422 / `/ws` error frame) and never reaches
|
| 56 |
+
`env.step`, so it has no code here.
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
DB_ERROR = "db_error"
|
| 60 |
+
UNKNOWN_TABLE = "unknown_table"
|
| 61 |
+
QUERY_TIMEOUT = "query_timeout"
|
| 62 |
+
RESULT_TOO_LARGE = "result_too_large"
|
| 63 |
+
SUBMIT_BEFORE_DIAGNOSE = "submit_before_diagnose"
|
| 64 |
+
INVALID_TOOL_ARGUMENT = "invalid_tool_argument"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# =============================================================================
|
| 68 |
+
# Tool payloads (request side of `SqlDriftAction`)
|
| 69 |
+
# =============================================================================
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class _BasePayload(BaseModel):
|
| 73 |
+
"""Shared config for every tool-call payload."""
|
| 74 |
+
|
| 75 |
+
model_config = ConfigDict(extra="forbid", validate_assignment=True)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class ListTablesPayload(_BasePayload):
|
| 79 |
+
kind: Literal["list_tables"] = "list_tables"
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class DescribeTablePayload(_BasePayload):
|
| 83 |
+
kind: Literal["describe_table"] = "describe_table"
|
| 84 |
+
table: str = Field(min_length=1, max_length=63)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class SampleRowsPayload(_BasePayload):
|
| 88 |
+
kind: Literal["sample_rows"] = "sample_rows"
|
| 89 |
+
table: str = Field(min_length=1, max_length=63)
|
| 90 |
+
limit: int = Field(default=5, ge=1, le=5)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class RunQueryPayload(_BasePayload):
|
| 94 |
+
kind: Literal["run_query"] = "run_query"
|
| 95 |
+
sql: str = Field(min_length=1, max_length=10_000)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class ExplainQueryPayload(_BasePayload):
|
| 99 |
+
kind: Literal["explain_query"] = "explain_query"
|
| 100 |
+
sql: str = Field(min_length=1, max_length=10_000)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class ReadChangelogPayload(_BasePayload):
|
| 104 |
+
kind: Literal["read_changelog"] = "read_changelog"
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class SubmitRewritePayload(_BasePayload):
|
| 108 |
+
kind: Literal["submit_rewrite"] = "submit_rewrite"
|
| 109 |
+
sql: str = Field(min_length=1, max_length=10_000)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class ConsultDBAPayload(_BasePayload):
|
| 113 |
+
kind: Literal["consult_dba"] = "consult_dba"
|
| 114 |
+
question: str = Field(min_length=1, max_length=400)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
ToolPayload = Annotated[
|
| 118 |
+
ListTablesPayload
|
| 119 |
+
| DescribeTablePayload
|
| 120 |
+
| SampleRowsPayload
|
| 121 |
+
| RunQueryPayload
|
| 122 |
+
| ExplainQueryPayload
|
| 123 |
+
| ReadChangelogPayload
|
| 124 |
+
| SubmitRewritePayload
|
| 125 |
+
| ConsultDBAPayload,
|
| 126 |
+
Field(discriminator="kind"),
|
| 127 |
+
]
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Tool -> payload-kind mapping; single source of truth for cross-validation
|
| 131 |
+
# and for the server-side dispatcher in P7.
|
| 132 |
+
TOOL_TO_PAYLOAD_KIND: dict[ToolName, str] = {
|
| 133 |
+
ToolName.LIST_TABLES: "list_tables",
|
| 134 |
+
ToolName.DESCRIBE_TABLE: "describe_table",
|
| 135 |
+
ToolName.SAMPLE_ROWS: "sample_rows",
|
| 136 |
+
ToolName.RUN_QUERY: "run_query",
|
| 137 |
+
ToolName.EXPLAIN_QUERY: "explain_query",
|
| 138 |
+
ToolName.READ_CHANGELOG: "read_changelog",
|
| 139 |
+
ToolName.SUBMIT_REWRITE: "submit_rewrite",
|
| 140 |
+
ToolName.CONSULT_DBA: "consult_dba",
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# =============================================================================
|
| 145 |
+
# SqlDriftAction envelope
|
| 146 |
+
# =============================================================================
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class SqlDriftAction(Action):
|
| 150 |
+
"""Tool-call envelope.
|
| 151 |
+
|
| 152 |
+
JSON wire format::
|
| 153 |
+
|
| 154 |
+
{"tool": "run_query", "payload": {"kind": "run_query", "sql": "..."}}
|
| 155 |
+
|
| 156 |
+
The `tool` field and `payload.kind` must agree; mismatch raises at
|
| 157 |
+
validation time.
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
tool: ToolName
|
| 161 |
+
payload: ToolPayload
|
| 162 |
+
|
| 163 |
+
@model_validator(mode="after")
|
| 164 |
+
def _tool_matches_payload(self) -> SqlDriftAction:
|
| 165 |
+
expected = TOOL_TO_PAYLOAD_KIND[self.tool]
|
| 166 |
+
if self.payload.kind != expected:
|
| 167 |
+
# PydanticCustomError keeps ``ctx`` JSON-serializable (plain
|
| 168 |
+
# strings only), unlike a bare ``ValueError`` which Pydantic
|
| 169 |
+
# wraps with ``ctx={"error": ValueError(...)}`` and breaks
|
| 170 |
+
# FastAPI HTTPException JSON encoder (422 responses).
|
| 171 |
+
raise PydanticCustomError(
|
| 172 |
+
"tool_payload_mismatch",
|
| 173 |
+
"tool/payload mismatch: tool={tool} expects payload.kind={expected}, got {got}",
|
| 174 |
+
{
|
| 175 |
+
"tool": self.tool.value,
|
| 176 |
+
"expected": expected,
|
| 177 |
+
"got": self.payload.kind,
|
| 178 |
+
},
|
| 179 |
+
)
|
| 180 |
+
return self
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# =============================================================================
|
| 184 |
+
# Tool results (response side of `SqlDriftObservation.tool_result`)
|
| 185 |
+
# =============================================================================
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class _BaseResult(BaseModel):
|
| 189 |
+
model_config = ConfigDict(extra="forbid", validate_assignment=True)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class ListTablesResult(_BaseResult):
|
| 193 |
+
kind: Literal["list_tables_result"] = "list_tables_result"
|
| 194 |
+
tables: list[str]
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class DescribeTableResult(_BaseResult):
|
| 198 |
+
kind: Literal["describe_table_result"] = "describe_table_result"
|
| 199 |
+
table: str
|
| 200 |
+
columns: list[dict[str, str]] # [{"name": "...", "type": "..."}]
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class SampleRowsResult(_BaseResult):
|
| 204 |
+
kind: Literal["sample_rows_result"] = "sample_rows_result"
|
| 205 |
+
table: str
|
| 206 |
+
columns: list[str]
|
| 207 |
+
rows: list[list[Any]]
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class RunQueryResult(_BaseResult):
|
| 211 |
+
kind: Literal["run_query_result"] = "run_query_result"
|
| 212 |
+
columns: list[str]
|
| 213 |
+
rows: list[list[Any]]
|
| 214 |
+
runtime_ms: float
|
| 215 |
+
row_count: int
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class ExplainQueryResult(_BaseResult):
|
| 219 |
+
kind: Literal["explain_query_result"] = "explain_query_result"
|
| 220 |
+
plan: str
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class ReadChangelogResult(_BaseResult):
|
| 224 |
+
kind: Literal["read_changelog_result"] = "read_changelog_result"
|
| 225 |
+
entries: list[str]
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class SubmitRewriteResult(_BaseResult):
|
| 229 |
+
kind: Literal["submit_rewrite_result"] = "submit_rewrite_result"
|
| 230 |
+
accepted: bool
|
| 231 |
+
runtime_ms: float
|
| 232 |
+
matches_ground_truth: bool
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class ConsultDBAResult(_BaseResult):
|
| 236 |
+
kind: Literal["consult_dba_result"] = "consult_dba_result"
|
| 237 |
+
tier: int = Field(ge=1, le=3)
|
| 238 |
+
hint: str
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class ToolError(_BaseResult):
|
| 242 |
+
kind: Literal["tool_error"] = "tool_error"
|
| 243 |
+
code: ToolErrorCode
|
| 244 |
+
message: str = Field(max_length=2_000)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
ToolResult = Annotated[
|
| 248 |
+
ListTablesResult
|
| 249 |
+
| DescribeTableResult
|
| 250 |
+
| SampleRowsResult
|
| 251 |
+
| RunQueryResult
|
| 252 |
+
| ExplainQueryResult
|
| 253 |
+
| ReadChangelogResult
|
| 254 |
+
| SubmitRewriteResult
|
| 255 |
+
| ConsultDBAResult
|
| 256 |
+
| ToolError,
|
| 257 |
+
Field(discriminator="kind"),
|
| 258 |
+
]
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# The six reward-component keys match the composed rubric; tests and telemetry
|
| 262 |
+
# rely on this exact schema.
|
| 263 |
+
REWARD_COMPONENT_KEYS: tuple[str, ...] = (
|
| 264 |
+
"r_correct",
|
| 265 |
+
"r_drift",
|
| 266 |
+
"r_speedup",
|
| 267 |
+
"r_step_tax",
|
| 268 |
+
"r_gatekeepers",
|
| 269 |
+
"r_consult_dba",
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
# =============================================================================
|
| 274 |
+
# SqlDriftObservation
|
| 275 |
+
# =============================================================================
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def _zero_reward_components() -> dict[str, float]:
|
| 279 |
+
"""Six-key reward envelope initialised to zero.
|
| 280 |
+
|
| 281 |
+
Every observation, including the reset observation, carries the full
|
| 282 |
+
six-key schema so telemetry and tests can index it unconditionally.
|
| 283 |
+
"""
|
| 284 |
+
return {key: 0.0 for key in REWARD_COMPONENT_KEYS}
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class SqlDriftObservation(Observation):
|
| 288 |
+
"""Observation returned by :meth:`SqlDriftEnvironment.step`.
|
| 289 |
+
|
| 290 |
+
Inherits `done: bool` and `reward: float | None` from base Observation.
|
| 291 |
+
|
| 292 |
+
The task payload (`baseline_sql`, `schema_synopsis`) is delivered on
|
| 293 |
+
the reset observation and kept empty on subsequent steps: the agent
|
| 294 |
+
is expected to capture it once and hold it in its own context.
|
| 295 |
+
"""
|
| 296 |
+
|
| 297 |
+
step: int = Field(ge=0)
|
| 298 |
+
phase: EpisodePhase
|
| 299 |
+
last_tool: ToolName | None = None
|
| 300 |
+
tool_result: ToolResult | None = None
|
| 301 |
+
drift_fired: bool = False
|
| 302 |
+
drift_acknowledged: bool = False
|
| 303 |
+
learned_hints: str = Field(default="", max_length=800)
|
| 304 |
+
baseline_sql: str = Field(default="", max_length=10_000)
|
| 305 |
+
schema_synopsis: str = Field(default="", max_length=2_000)
|
| 306 |
+
budget_steps_remaining: int = Field(ge=0)
|
| 307 |
+
reward_components: dict[str, float] = Field(default_factory=_zero_reward_components)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
# =============================================================================
|
| 311 |
+
# SqlDriftState — PUBLIC state (sanitized)
|
| 312 |
+
# =============================================================================
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
class SqlDriftState(State):
|
| 316 |
+
"""Public state snapshot — serialized over `/state`.
|
| 317 |
+
|
| 318 |
+
Ground truth, DB handles, seeds, and baseline SQL live in
|
| 319 |
+
:class:`engine.runtime.RuntimeEpisodeState` and are never exposed here.
|
| 320 |
+
`extra="forbid"` guarantees no accidental leak via future field additions.
|
| 321 |
+
"""
|
| 322 |
+
|
| 323 |
+
model_config = ConfigDict(
|
| 324 |
+
extra="forbid",
|
| 325 |
+
validate_assignment=True,
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
scenario_id: str
|
| 329 |
+
phase: EpisodePhase
|
| 330 |
+
budget_steps_remaining: int = Field(ge=0)
|
| 331 |
+
drift_fired: bool = False
|
| 332 |
+
consultations_used: int = Field(default=0, ge=0)
|
| 333 |
+
submitted: bool = False
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
__all__ = [
|
| 337 |
+
"ConsultDBAPayload",
|
| 338 |
+
"ConsultDBAResult",
|
| 339 |
+
"DescribeTablePayload",
|
| 340 |
+
"DescribeTableResult",
|
| 341 |
+
"EpisodePhase",
|
| 342 |
+
"ExplainQueryPayload",
|
| 343 |
+
"ExplainQueryResult",
|
| 344 |
+
"ListTablesPayload",
|
| 345 |
+
"ListTablesResult",
|
| 346 |
+
"REWARD_COMPONENT_KEYS",
|
| 347 |
+
"ReadChangelogPayload",
|
| 348 |
+
"ReadChangelogResult",
|
| 349 |
+
"RunQueryPayload",
|
| 350 |
+
"RunQueryResult",
|
| 351 |
+
"SampleRowsPayload",
|
| 352 |
+
"SampleRowsResult",
|
| 353 |
+
"SqlDriftAction",
|
| 354 |
+
"SqlDriftObservation",
|
| 355 |
+
"SqlDriftState",
|
| 356 |
+
"SubmitRewritePayload",
|
| 357 |
+
"SubmitRewriteResult",
|
| 358 |
+
"TOOL_TO_PAYLOAD_KIND",
|
| 359 |
+
"ToolError",
|
| 360 |
+
"ToolErrorCode",
|
| 361 |
+
"ToolName",
|
| 362 |
+
"ToolPayload",
|
| 363 |
+
"ToolResult",
|
| 364 |
+
]
|
openenv.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: sql_drift_env
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
| 7 |
+
|
pyproject.toml
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "sql_drift_env"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "SQLDrift: OpenEnv gym for repairing and optimizing SQL under live schema/business-rule drift"
|
| 15 |
+
requires-python = ">=3.12,<3.14"
|
| 16 |
+
dependencies = [
|
| 17 |
+
"duckdb>=1.5.2,<2.0",
|
| 18 |
+
"huggingface-hub>=0.20,<1.0",
|
| 19 |
+
"openenv-core[core]>=0.2.2,<0.4",
|
| 20 |
+
"sqlglot>=30.6.0,<40.0",
|
| 21 |
+
"pydantic>=2.8.0,<3.0",
|
| 22 |
+
"python-dotenv>=1.2.2,<2.0",
|
| 23 |
+
"openai>=2.32.0,<3.0",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
[project.optional-dependencies]
|
| 27 |
+
evidence = [
|
| 28 |
+
"matplotlib>=3.8.0,<4.0",
|
| 29 |
+
"pandas>=2.0.0,<3.0",
|
| 30 |
+
]
|
| 31 |
+
dev = [
|
| 32 |
+
"mypy>=1.20.1",
|
| 33 |
+
"pytest>=9.0.3",
|
| 34 |
+
"pytest-asyncio>=1.3.0",
|
| 35 |
+
"pytest-cov>=7.0.0",
|
| 36 |
+
"ruff>=0.15.11",
|
| 37 |
+
"httpx>=0.28.0",
|
| 38 |
+
]
|
| 39 |
+
train = [
|
| 40 |
+
"trl>=0.25.0,<1.0",
|
| 41 |
+
"datasets>=2.20.0,<4.0",
|
| 42 |
+
"transformers>=4.46.0,<5.0",
|
| 43 |
+
"accelerate>=1.0.0,<2.0",
|
| 44 |
+
"peft>=0.13,<1.0",
|
| 45 |
+
"bitsandbytes>=0.43,<1.0",
|
| 46 |
+
"unsloth>=2024.9,<2027.0",
|
| 47 |
+
]
|
| 48 |
+
|
| 49 |
+
[project.scripts]
|
| 50 |
+
# Flat package layout; see _cli.py. Runtime path: site-packages/sql_drift_env/
|
| 51 |
+
server = "sql_drift_env._cli:main"
|
| 52 |
+
|
| 53 |
+
[tool.setuptools]
|
| 54 |
+
include-package-data = true
|
| 55 |
+
packages = [
|
| 56 |
+
"sql_drift_env",
|
| 57 |
+
"sql_drift_env.scenarios",
|
| 58 |
+
"sql_drift_env.engine",
|
| 59 |
+
"sql_drift_env.skill_library",
|
| 60 |
+
"sql_drift_env.actors",
|
| 61 |
+
"sql_drift_env.training",
|
| 62 |
+
"sql_drift_env.server",
|
| 63 |
+
"sql_drift_env.utilities",
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
[tool.setuptools.package-dir]
|
| 67 |
+
"sql_drift_env" = "."
|
| 68 |
+
"sql_drift_env.scenarios" = "scenarios"
|
| 69 |
+
"sql_drift_env.engine" = "engine"
|
| 70 |
+
"sql_drift_env.skill_library" = "skill_library"
|
| 71 |
+
"sql_drift_env.actors" = "actors"
|
| 72 |
+
"sql_drift_env.training" = "training"
|
| 73 |
+
"sql_drift_env.server" = "server"
|
| 74 |
+
"sql_drift_env.utilities" = "utilities"
|
| 75 |
+
|
| 76 |
+
[tool.pytest.ini_options]
|
| 77 |
+
testpaths = ["tests"]
|
| 78 |
+
asyncio_mode = "auto"
|
| 79 |
+
markers = [
|
| 80 |
+
"slow: marks tests as slow (excluded by default; run via `pytest -m slow`)",
|
| 81 |
+
]
|
| 82 |
+
# Default excludes slow tests so CI stays on a fast path; run `pytest -m slow` for full suite.
|
| 83 |
+
addopts = "-ra --strict-markers -m 'not slow'"
|
| 84 |
+
|
| 85 |
+
[tool.ruff]
|
| 86 |
+
line-length = 100
|
| 87 |
+
target-version = "py312"
|
| 88 |
+
extend-exclude = [".venv", "design", "*.egg-info", "*.ipynb"]
|
| 89 |
+
|
| 90 |
+
[tool.ruff.lint]
|
| 91 |
+
select = ["E", "F", "I", "B", "UP", "SIM"]
|
| 92 |
+
ignore = ["E501"]
|
| 93 |
+
|
| 94 |
+
# Prefer fixing structure over scattered inline suppressions. See each rule in Ruff docs.
|
| 95 |
+
[tool.ruff.lint.per-file-ignores]
|
| 96 |
+
"_cli.py" = ["PLC0415"]
|
| 97 |
+
"training/grpo_train.py" = ["PLC0415"]
|
| 98 |
+
"engine/profiler.py" = ["BLE001"]
|
| 99 |
+
"scenarios/__init__.py" = ["F401"]
|
| 100 |
+
"tests/unit/test_p0_smoke.py" = ["F401"]
|
| 101 |
+
"utilities/verbose_api_rollout.py" = ["E402"]
|
| 102 |
+
"utilities/demo_rollout.py" = ["E402"]
|
| 103 |
+
"tests/unit/test_drift_scenarios.py" = ["E402"]
|
| 104 |
+
"tests/unit/test_profiler.py" = ["E402"]
|
| 105 |
+
"tests/unit/test_drift.py" = ["E402"]
|
| 106 |
+
"tests/unit/test_runtime.py" = ["E402"]
|
| 107 |
+
|
| 108 |
+
[tool.mypy]
|
| 109 |
+
python_version = "3.12"
|
| 110 |
+
strict = true
|
| 111 |
+
ignore_missing_imports = true
|
| 112 |
+
explicit_package_bases = true
|
| 113 |
+
mypy_path = "."
|
| 114 |
+
files = [
|
| 115 |
+
"_cli.py",
|
| 116 |
+
"actors",
|
| 117 |
+
"client.py",
|
| 118 |
+
"engine",
|
| 119 |
+
"models.py",
|
| 120 |
+
"scenarios",
|
| 121 |
+
"utilities",
|
| 122 |
+
"server",
|
| 123 |
+
"skill_library",
|
| 124 |
+
"training",
|
| 125 |
+
]
|
| 126 |
+
exclude = [
|
| 127 |
+
"\\.venv",
|
| 128 |
+
"design",
|
| 129 |
+
".*\\.egg-info",
|
| 130 |
+
"tests",
|
| 131 |
+
"build",
|
| 132 |
+
"dist",
|
| 133 |
+
]
|
| 134 |
+
|
| 135 |
+
# Subclasses in these modules call into openenv-core (no stubs). Relax only here.
|
| 136 |
+
[[tool.mypy.overrides]]
|
| 137 |
+
module = [
|
| 138 |
+
"client",
|
| 139 |
+
"models",
|
| 140 |
+
"engine.reward",
|
| 141 |
+
"server.app",
|
| 142 |
+
"server.sql_drift_env_environment",
|
| 143 |
+
"training.grpo_train",
|
| 144 |
+
]
|
| 145 |
+
disallow_subclassing_any = false
|
| 146 |
+
warn_return_any = false
|
| 147 |
+
|
| 148 |
+
[tool.coverage.run]
|
| 149 |
+
# pytest-cov can hit upstream issues with beartype when tracing; run plain pytest for CI.
|
| 150 |
+
source = ["engine", "scenarios", "skill_library", "actors", "models.py"]
|
| 151 |
+
omit = ["tests/*", ".venv/*"]
|
| 152 |
+
|
| 153 |
+
[tool.coverage.report]
|
| 154 |
+
skip_empty = true
|
| 155 |
+
exclude_lines = [
|
| 156 |
+
"pragma: no cover",
|
| 157 |
+
"if __name__ == \"__main__\":",
|
| 158 |
+
"if TYPE_CHECKING:",
|
| 159 |
+
"raise NotImplementedError",
|
| 160 |
+
"if 0:",
|
| 161 |
+
"if False:",
|
| 162 |
+
"\\.\\.\\.",
|
| 163 |
+
]
|
scenarios/__init__.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scenario registry + auto-discovery.
|
| 2 |
+
|
| 3 |
+
Every sibling module whose name starts with ``scenario_`` is expected to
|
| 4 |
+
export a module-level ``SPEC: ScenarioSpec``. This module walks the package
|
| 5 |
+
at import time and populates :data:`REGISTRY` so callers (env reset, tests,
|
| 6 |
+
CLI eval) can look scenarios up by id without knowing the file layout.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import importlib
|
| 12 |
+
import pkgutil
|
| 13 |
+
|
| 14 |
+
from .base import (
|
| 15 |
+
DriftConfig,
|
| 16 |
+
DriftKind,
|
| 17 |
+
Family,
|
| 18 |
+
ScenarioInstance,
|
| 19 |
+
ScenarioSpec,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
REGISTRY: dict[str, ScenarioSpec] = {}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _discover_specs() -> None:
|
| 26 |
+
"""Import every `scenario_*` sibling module and harvest their SPEC."""
|
| 27 |
+
package = __name__ # "scenarios"
|
| 28 |
+
package_path = __path__ # provided by Python's package machinery
|
| 29 |
+
|
| 30 |
+
for info in pkgutil.iter_modules(package_path):
|
| 31 |
+
if not info.name.startswith("scenario_"):
|
| 32 |
+
continue
|
| 33 |
+
module = importlib.import_module(f"{package}.{info.name}")
|
| 34 |
+
spec = getattr(module, "SPEC", None)
|
| 35 |
+
if spec is None:
|
| 36 |
+
raise RuntimeError(f"{package}.{info.name} is missing a module-level `SPEC` export")
|
| 37 |
+
if not isinstance(spec, ScenarioSpec):
|
| 38 |
+
raise TypeError(
|
| 39 |
+
f"{package}.{info.name}.SPEC is {type(spec).__name__}; expected ScenarioSpec"
|
| 40 |
+
)
|
| 41 |
+
if spec.scenario_id in REGISTRY:
|
| 42 |
+
raise RuntimeError(
|
| 43 |
+
f"duplicate scenario_id {spec.scenario_id!r} — "
|
| 44 |
+
f"already registered from {REGISTRY[spec.scenario_id]!r}"
|
| 45 |
+
)
|
| 46 |
+
REGISTRY[spec.scenario_id] = spec
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def iter_specs() -> list[ScenarioSpec]:
|
| 50 |
+
return sorted(REGISTRY.values(), key=lambda s: s.scenario_id)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_spec(scenario_id: str) -> ScenarioSpec:
|
| 54 |
+
try:
|
| 55 |
+
return REGISTRY[scenario_id]
|
| 56 |
+
except KeyError as e:
|
| 57 |
+
raise KeyError(f"unknown scenario_id={scenario_id!r}; known: {sorted(REGISTRY)}") from e
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
_discover_specs()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
__all__ = [
|
| 64 |
+
"DriftConfig",
|
| 65 |
+
"DriftKind",
|
| 66 |
+
"Family",
|
| 67 |
+
"REGISTRY",
|
| 68 |
+
"ScenarioInstance",
|
| 69 |
+
"ScenarioSpec",
|
| 70 |
+
"get_spec",
|
| 71 |
+
"iter_specs",
|
| 72 |
+
]
|
scenarios/_fixtures.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Deterministic fixture generators (stdlib only).
|
| 2 |
+
|
| 3 |
+
All distributions are backed by :class:`random.Random(seed)` so a given
|
| 4 |
+
``(scenario_id, seed, scale)`` tuple always yields the same table contents.
|
| 5 |
+
|
| 6 |
+
Exposes a single :func:`seeded_rng` factory plus a handful of domain-specific
|
| 7 |
+
generators used by the concrete scenarios. No numpy/pandas runtime deps.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
import random
|
| 14 |
+
from collections.abc import Sequence
|
| 15 |
+
from datetime import UTC
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def seeded_rng(*parts: int | str) -> random.Random:
|
| 19 |
+
"""Derive a deterministic `random.Random` from mixed scalar parts.
|
| 20 |
+
|
| 21 |
+
Uses a stable 64-bit SplitMix-style hash over ``repr(parts)`` — avoids
|
| 22 |
+
Python's per-interpreter salted ``hash()`` for str.
|
| 23 |
+
"""
|
| 24 |
+
h = 1469598103934665603 # FNV-1a 64-bit offset basis
|
| 25 |
+
for p in parts:
|
| 26 |
+
for byte in repr(p).encode():
|
| 27 |
+
h ^= byte
|
| 28 |
+
h = (h * 1099511628211) & 0xFFFFFFFFFFFFFFFF
|
| 29 |
+
return random.Random(h)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def lognormal_amounts(rng: random.Random, n: int, mu: float, sigma: float) -> list[float]:
|
| 33 |
+
"""n log-normally distributed positive amounts, rounded to 2dp."""
|
| 34 |
+
return [round(rng.lognormvariate(mu, sigma), 2) for _ in range(n)]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def zipfian_choices(
|
| 38 |
+
rng: random.Random, population: Sequence[int], n: int, *, alpha: float = 1.07
|
| 39 |
+
) -> list[int]:
|
| 40 |
+
"""n draws from `population` with zipf(alpha) weights.
|
| 41 |
+
|
| 42 |
+
`alpha` controls skew: 1.07 is the project default. Uses
|
| 43 |
+
`random.choices` with explicit weights, so the draws are stable across
|
| 44 |
+
Python versions for a given `rng` state.
|
| 45 |
+
"""
|
| 46 |
+
weights = [1.0 / ((i + 1) ** alpha) for i in range(len(population))]
|
| 47 |
+
return rng.choices(list(population), weights=weights, k=n)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def date_range_epoch_ms(
|
| 51 |
+
rng: random.Random,
|
| 52 |
+
n: int,
|
| 53 |
+
*,
|
| 54 |
+
start_epoch_ms: int,
|
| 55 |
+
window_days: int,
|
| 56 |
+
) -> list[int]:
|
| 57 |
+
"""n random timestamps (ms) within `[start, start + window_days)`."""
|
| 58 |
+
span_ms = window_days * 86_400_000
|
| 59 |
+
return [start_epoch_ms + rng.randrange(span_ms) for _ in range(n)]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def iso_strings_from_epoch_ms(epoch_ms: list[int]) -> list[str]:
|
| 63 |
+
"""Convert epoch ms to ISO-8601 UTC strings (matching DuckDB's native coerce)."""
|
| 64 |
+
from datetime import datetime
|
| 65 |
+
|
| 66 |
+
return [
|
| 67 |
+
datetime.fromtimestamp(t / 1000, tz=UTC).isoformat().replace("+00:00", "Z")
|
| 68 |
+
for t in epoch_ms
|
| 69 |
+
]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def categorical_choices(
|
| 73 |
+
rng: random.Random,
|
| 74 |
+
categories: Sequence[str],
|
| 75 |
+
n: int,
|
| 76 |
+
*,
|
| 77 |
+
weights: Sequence[float] | None = None,
|
| 78 |
+
) -> list[str]:
|
| 79 |
+
return rng.choices(list(categories), weights=list(weights) if weights else None, k=n)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def unique_names(rng: random.Random, n: int, *, prefix: str = "name") -> list[str]:
|
| 83 |
+
"""Stable pseudo-unique string IDs of the form `<prefix>_<64-bit-hex>`."""
|
| 84 |
+
return [f"{prefix}_{rng.getrandbits(64):016x}" for _ in range(n)]
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def approx_normal(
|
| 88 |
+
rng: random.Random, n: int, *, mu: float, sigma: float, clip_lo: float | None = None
|
| 89 |
+
) -> list[float]:
|
| 90 |
+
"""n normal draws, optionally clipped below."""
|
| 91 |
+
out: list[float] = []
|
| 92 |
+
for _ in range(n):
|
| 93 |
+
x = rng.gauss(mu, sigma)
|
| 94 |
+
if clip_lo is not None and x < clip_lo:
|
| 95 |
+
x = clip_lo
|
| 96 |
+
out.append(round(x, 4))
|
| 97 |
+
return out
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def sanity_nonzero_variance(xs: Sequence[float]) -> bool:
|
| 101 |
+
"""Guard: reject obviously degenerate distributions (used in smoke tests)."""
|
| 102 |
+
if not xs:
|
| 103 |
+
return False
|
| 104 |
+
mean = sum(xs) / len(xs)
|
| 105 |
+
var = sum((x - mean) ** 2 for x in xs) / len(xs)
|
| 106 |
+
return var > 1e-12 and not math.isnan(var)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
__all__ = [
|
| 110 |
+
"approx_normal",
|
| 111 |
+
"categorical_choices",
|
| 112 |
+
"date_range_epoch_ms",
|
| 113 |
+
"iso_strings_from_epoch_ms",
|
| 114 |
+
"lognormal_amounts",
|
| 115 |
+
"sanity_nonzero_variance",
|
| 116 |
+
"seeded_rng",
|
| 117 |
+
"unique_names",
|
| 118 |
+
"zipfian_choices",
|
| 119 |
+
]
|
scenarios/base.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scenario spec + instance primitives.
|
| 2 |
+
|
| 3 |
+
Each concrete scenario file in :mod:`scenarios` exports:
|
| 4 |
+
|
| 5 |
+
- ``SPEC: ScenarioSpec`` — the immutable metadata (id, family, tags,
|
| 6 |
+
optional drift config) plus a bound ``builder`` callable.
|
| 7 |
+
|
| 8 |
+
The builder takes ``(spec, seed, scale)`` and returns a ready-to-attach
|
| 9 |
+
:class:`ScenarioInstance` whose DuckDB connection has been loaded with
|
| 10 |
+
deterministic fixtures, ground-truth hashes pre-computed, and baseline
|
| 11 |
+
runtime measured. ``base_scale`` is author-tuned per scenario so the
|
| 12 |
+
measured baseline clears :data:`BASELINE_MIN_MS` on a single build —
|
| 13 |
+
the old timing-driven reroll loop was removed because it coupled the
|
| 14 |
+
fixture RNG seed to the retry count, which destroyed determinism
|
| 15 |
+
whenever CI hit a jitter-induced retry.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
from collections.abc import Callable
|
| 21 |
+
from dataclasses import dataclass, field
|
| 22 |
+
from typing import TYPE_CHECKING, Any, Literal
|
| 23 |
+
|
| 24 |
+
if TYPE_CHECKING:
|
| 25 |
+
import duckdb
|
| 26 |
+
|
| 27 |
+
from utilities.logger import get_module_logger
|
| 28 |
+
|
| 29 |
+
_LOG = get_module_logger(__name__)
|
| 30 |
+
|
| 31 |
+
Family = Literal["ecommerce", "events", "cms", "saas_logs", "multitenant"]
|
| 32 |
+
DriftKind = Literal["column_rename", "date_format", "enum_rule", "field_deprecation"]
|
| 33 |
+
DifficultyLevel = Literal["easy", "normal", "hard"]
|
| 34 |
+
|
| 35 |
+
# Baseline floor — empirically calibrated, not a fixed marketing target
|
| 36 |
+
# 50 ms. DuckDB on in-memory fixtures at CI-reasonable scales (a few
|
| 37 |
+
# tens of thousands of rows) measures warm baselines of 0.4–2 ms on
|
| 38 |
+
# the anti-pattern queries shipped here; reaching 50 ms would require
|
| 39 |
+
# multi-minute fixture builds per scenario, which is untenable for
|
| 40 |
+
# both CI and RL rollouts (every reset rebuilds).
|
| 41 |
+
#
|
| 42 |
+
# 0.3 ms is ~3–5× the median-of-3 warm jitter floor on a quiet CPU
|
| 43 |
+
# (observed jitter ~60–100 µs). This SNR is tight but workable because
|
| 44 |
+
# the rubric gates the speedup reward at 1.2× before any
|
| 45 |
+
# credit is issued, so jitter-induced near-1× "speedups" score zero.
|
| 46 |
+
# The cap at 64× bounds upside. A 2× rewrite against a 0.3 ms baseline
|
| 47 |
+
# lands at 0.15 ms — still distinguishable from jitter under
|
| 48 |
+
# median-of-3 smoothing.
|
| 49 |
+
#
|
| 50 |
+
# The same floor applies in production and CI — no env-var escape
|
| 51 |
+
# hatch — so tests exercise the real reward distribution. Per-scenario
|
| 52 |
+
# overrides may raise *or* lower this floor when a scenario's query
|
| 53 |
+
# shape has a different natural baseline (see the field docstring on
|
| 54 |
+
# :class:`ScenarioSpec.baseline_min_ms`).
|
| 55 |
+
BASELINE_MIN_MS = 0.3
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@dataclass(frozen=True)
|
| 59 |
+
class DriftConfig:
|
| 60 |
+
kind: DriftKind
|
| 61 |
+
payload: dict[str, Any]
|
| 62 |
+
min_step: int = 6
|
| 63 |
+
max_step: int = 12
|
| 64 |
+
cooldown_steps: int = 2
|
| 65 |
+
|
| 66 |
+
def __post_init__(self) -> None:
|
| 67 |
+
if self.min_step < 1:
|
| 68 |
+
raise ValueError("min_step must be >= 1")
|
| 69 |
+
if self.max_step < self.min_step:
|
| 70 |
+
raise ValueError("max_step must be >= min_step")
|
| 71 |
+
if self.cooldown_steps < 0:
|
| 72 |
+
raise ValueError("cooldown_steps must be >= 0")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass
|
| 76 |
+
class ScenarioInstance:
|
| 77 |
+
"""Concretized scenario — ready-to-attach DuckDB fixture + ground truths."""
|
| 78 |
+
|
| 79 |
+
conn: duckdb.DuckDBPyConnection
|
| 80 |
+
baseline_sql: str
|
| 81 |
+
gt_sql_predrift: str
|
| 82 |
+
gt_sql_postdrift: str | None
|
| 83 |
+
baseline_runtime_ms: float
|
| 84 |
+
baseline_tokens: int
|
| 85 |
+
gt_result_hash_predrift: str
|
| 86 |
+
gt_result_hash_postdrift: str | None
|
| 87 |
+
drift_config: DriftConfig | None
|
| 88 |
+
schema_synopsis: str
|
| 89 |
+
# Drift-distinctive identifier sets consumed by the drift-adapt
|
| 90 |
+
# rubric. ``postdrift_identifiers`` marks identifiers/literals
|
| 91 |
+
# the correct post-drift rewrite MUST introduce; ``predrift_identifiers``
|
| 92 |
+
# marks identifiers/literals a submission that ignored the drift
|
| 93 |
+
# WOULD retain. Together they let the rubric distinguish "adapted"
|
| 94 |
+
# from "did not adapt" for drift kinds where a single identifier
|
| 95 |
+
# (e.g. ``ts`` under date-format drift) is shared by both sides.
|
| 96 |
+
postdrift_identifiers: frozenset[str] = field(default_factory=frozenset)
|
| 97 |
+
predrift_identifiers: frozenset[str] = field(default_factory=frozenset)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# Builder signature: (spec, seed, scale) -> (conn, baseline_sql,
|
| 101 |
+
# gt_sql_predrift, gt_sql_postdrift, schema_synopsis,
|
| 102 |
+
# postdrift_identifiers, predrift_identifiers).
|
| 103 |
+
BuilderResult = tuple[
|
| 104 |
+
"duckdb.DuckDBPyConnection",
|
| 105 |
+
str, # baseline_sql
|
| 106 |
+
str, # gt_sql_predrift
|
| 107 |
+
str | None, # gt_sql_postdrift
|
| 108 |
+
str, # schema_synopsis
|
| 109 |
+
frozenset[str], # postdrift_identifiers
|
| 110 |
+
frozenset[str], # predrift_identifiers
|
| 111 |
+
]
|
| 112 |
+
BuilderFn = Callable[["ScenarioSpec", int, int], BuilderResult]
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@dataclass(frozen=True)
|
| 116 |
+
class ScenarioSpec:
|
| 117 |
+
"""Immutable scenario metadata + bound builder."""
|
| 118 |
+
|
| 119 |
+
scenario_id: str
|
| 120 |
+
family: Family
|
| 121 |
+
tags: frozenset[str]
|
| 122 |
+
drift_config: DriftConfig | None
|
| 123 |
+
builder: BuilderFn
|
| 124 |
+
# Row-count scale passed to the builder. Author-tuned so the
|
| 125 |
+
# measured baseline clears ``baseline_min_ms`` on a single build;
|
| 126 |
+
# materialize() emits a warning (but does not retry) if the floor
|
| 127 |
+
# is not met, signalling the author to bump this value.
|
| 128 |
+
base_scale: int = 1_000
|
| 129 |
+
# Per-scenario baseline floor override. Most scenarios inherit the
|
| 130 |
+
# module default. Scenarios whose query shape naturally lands at a
|
| 131 |
+
# very different baseline (e.g. a trivial single-table GROUP BY
|
| 132 |
+
# that can't be meaningfully sped up, or a large join whose raw
|
| 133 |
+
# shape is already expensive) can pin a different floor with a
|
| 134 |
+
# documented rationale at the SPEC site.
|
| 135 |
+
baseline_min_ms: float = BASELINE_MIN_MS
|
| 136 |
+
|
| 137 |
+
def materialize(self, seed: int, *, difficulty: DifficultyLevel = "normal") -> ScenarioInstance:
|
| 138 |
+
return materialize(self, seed, difficulty=difficulty)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def count_tokens(sql: str) -> int:
|
| 142 |
+
"""Rough whitespace/punctuation token count — good enough for baseline."""
|
| 143 |
+
import re
|
| 144 |
+
|
| 145 |
+
return len(re.findall(r"[\w]+|[^\s\w]", sql))
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _scale_for_difficulty(base_scale: int, difficulty: DifficultyLevel) -> int:
|
| 149 |
+
"""Map a coarse difficulty level onto the scenario builder's row-count scale."""
|
| 150 |
+
if difficulty == "easy":
|
| 151 |
+
return max(1, base_scale // 2)
|
| 152 |
+
if difficulty == "hard":
|
| 153 |
+
return base_scale * 2
|
| 154 |
+
return base_scale
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def materialize(
|
| 158 |
+
spec: ScenarioSpec, seed: int, *, difficulty: DifficultyLevel = "normal"
|
| 159 |
+
) -> ScenarioInstance:
|
| 160 |
+
"""Build a ScenarioInstance once, measure baseline, and return.
|
| 161 |
+
|
| 162 |
+
Single build — deterministic, no retry. If the measured baseline is
|
| 163 |
+
below ``spec.baseline_min_ms`` a warning is logged so scenario
|
| 164 |
+
authors can bump ``base_scale``; the instance is still returned so
|
| 165 |
+
episodes can proceed (the rubric gracefully handles small
|
| 166 |
+
baselines via the 1.2× speedup gate and infinite-speedup cap).
|
| 167 |
+
"""
|
| 168 |
+
from engine.profiler import median_of_3_warm_ms
|
| 169 |
+
from engine.verifier import canonical_row_hash
|
| 170 |
+
|
| 171 |
+
scale = _scale_for_difficulty(spec.base_scale, difficulty)
|
| 172 |
+
|
| 173 |
+
(
|
| 174 |
+
conn,
|
| 175 |
+
baseline_sql,
|
| 176 |
+
gt_pre,
|
| 177 |
+
gt_post,
|
| 178 |
+
synopsis,
|
| 179 |
+
postdrift_ids,
|
| 180 |
+
predrift_ids,
|
| 181 |
+
) = spec.builder(spec, seed, scale)
|
| 182 |
+
try:
|
| 183 |
+
baseline_ms = median_of_3_warm_ms(conn, baseline_sql)
|
| 184 |
+
except Exception:
|
| 185 |
+
conn.close()
|
| 186 |
+
raise
|
| 187 |
+
if baseline_ms < spec.baseline_min_ms:
|
| 188 |
+
_LOG.warning(
|
| 189 |
+
"%s: baseline %.2fms < %.2fms floor at difficulty=%s scale=%d — bump base_scale",
|
| 190 |
+
spec.scenario_id,
|
| 191 |
+
baseline_ms,
|
| 192 |
+
spec.baseline_min_ms,
|
| 193 |
+
difficulty,
|
| 194 |
+
scale,
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
pre_rows = conn.execute(gt_pre).fetchall()
|
| 198 |
+
gt_hash_pre = canonical_row_hash(pre_rows)
|
| 199 |
+
# Post-drift ground-truth hashes are computed AFTER drift is applied
|
| 200 |
+
# at runtime — not here. The env backfills them from gt_post once
|
| 201 |
+
# drift fires.
|
| 202 |
+
return ScenarioInstance(
|
| 203 |
+
conn=conn,
|
| 204 |
+
baseline_sql=baseline_sql,
|
| 205 |
+
gt_sql_predrift=gt_pre,
|
| 206 |
+
gt_sql_postdrift=gt_post,
|
| 207 |
+
baseline_runtime_ms=baseline_ms,
|
| 208 |
+
baseline_tokens=count_tokens(baseline_sql),
|
| 209 |
+
gt_result_hash_predrift=gt_hash_pre,
|
| 210 |
+
gt_result_hash_postdrift=None,
|
| 211 |
+
drift_config=spec.drift_config,
|
| 212 |
+
schema_synopsis=synopsis,
|
| 213 |
+
postdrift_identifiers=postdrift_ids,
|
| 214 |
+
predrift_identifiers=predrift_ids,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
__all__ = [
|
| 219 |
+
"BASELINE_MIN_MS",
|
| 220 |
+
"BuilderFn",
|
| 221 |
+
"BuilderResult",
|
| 222 |
+
"DifficultyLevel",
|
| 223 |
+
"DriftConfig",
|
| 224 |
+
"DriftKind",
|
| 225 |
+
"Family",
|
| 226 |
+
"ScenarioInstance",
|
| 227 |
+
"ScenarioSpec",
|
| 228 |
+
"count_tokens",
|
| 229 |
+
"materialize",
|
| 230 |
+
]
|
scenarios/scenario_01_correlated_subquery.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scenario 01 — correlated subquery anti-pattern.
|
| 2 |
+
|
| 3 |
+
Baseline computes a per-user order-count via a correlated subquery in the
|
| 4 |
+
projection list. The rewrite replaces it with a LEFT JOIN on a grouped
|
| 5 |
+
aggregate so the query runs once instead of once per outer row.
|
| 6 |
+
|
| 7 |
+
Schema: users(id, signup_month, tier), orders(id, user_id, amount, status).
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import duckdb
|
| 13 |
+
|
| 14 |
+
from ._fixtures import (
|
| 15 |
+
categorical_choices,
|
| 16 |
+
lognormal_amounts,
|
| 17 |
+
seeded_rng,
|
| 18 |
+
zipfian_choices,
|
| 19 |
+
)
|
| 20 |
+
from .base import BuilderResult, ScenarioSpec
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
|
| 24 |
+
rng = seeded_rng(spec.scenario_id, seed, scale)
|
| 25 |
+
n_users = scale
|
| 26 |
+
n_orders = scale * 6
|
| 27 |
+
|
| 28 |
+
user_ids = list(range(1, n_users + 1))
|
| 29 |
+
tiers = categorical_choices(
|
| 30 |
+
rng, ["free", "pro", "business"], n_users, weights=[0.7, 0.25, 0.05]
|
| 31 |
+
)
|
| 32 |
+
signup_months = rng.choices(list(range(1, 13)), k=n_users)
|
| 33 |
+
|
| 34 |
+
order_user_ids = zipfian_choices(rng, user_ids, n_orders)
|
| 35 |
+
amounts = lognormal_amounts(rng, n_orders, mu=3.2, sigma=0.9)
|
| 36 |
+
statuses = categorical_choices(
|
| 37 |
+
rng,
|
| 38 |
+
["placed", "fulfilled", "refunded", "cancelled"],
|
| 39 |
+
n_orders,
|
| 40 |
+
weights=[0.6, 0.3, 0.05, 0.05],
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
conn = duckdb.connect(":memory:")
|
| 44 |
+
conn.execute("CREATE TABLE users( id BIGINT PRIMARY KEY, signup_month INTEGER, tier VARCHAR);")
|
| 45 |
+
conn.execute(
|
| 46 |
+
"CREATE TABLE orders("
|
| 47 |
+
" id BIGINT PRIMARY KEY, user_id BIGINT, amount DOUBLE, status VARCHAR);"
|
| 48 |
+
)
|
| 49 |
+
conn.executemany(
|
| 50 |
+
"INSERT INTO users VALUES (?, ?, ?)",
|
| 51 |
+
list(zip(user_ids, signup_months, tiers, strict=False)),
|
| 52 |
+
)
|
| 53 |
+
conn.executemany(
|
| 54 |
+
"INSERT INTO orders VALUES (?, ?, ?, ?)",
|
| 55 |
+
[
|
| 56 |
+
(oid, uid, amt, st)
|
| 57 |
+
for oid, (uid, amt, st) in enumerate(
|
| 58 |
+
zip(order_user_ids, amounts, statuses, strict=False), start=1
|
| 59 |
+
)
|
| 60 |
+
],
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
baseline_sql = (
|
| 64 |
+
"SELECT u.id, u.tier, "
|
| 65 |
+
"(SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id AND o.status = 'fulfilled') "
|
| 66 |
+
" AS fulfilled_orders "
|
| 67 |
+
"FROM users u "
|
| 68 |
+
"ORDER BY u.id"
|
| 69 |
+
)
|
| 70 |
+
gt_sql_predrift = (
|
| 71 |
+
"SELECT u.id, u.tier, COALESCE(c.n, 0) AS fulfilled_orders "
|
| 72 |
+
"FROM users u LEFT JOIN ("
|
| 73 |
+
" SELECT user_id, COUNT(*) AS n FROM orders "
|
| 74 |
+
" WHERE status = 'fulfilled' GROUP BY user_id"
|
| 75 |
+
") c ON c.user_id = u.id "
|
| 76 |
+
"ORDER BY u.id"
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
synopsis = (
|
| 80 |
+
"users(id PK, signup_month, tier); orders(id PK, user_id→users.id, amount, status). "
|
| 81 |
+
"Baseline scans orders once per user via a correlated subquery."
|
| 82 |
+
)
|
| 83 |
+
return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
SPEC = ScenarioSpec(
|
| 87 |
+
scenario_id="01_correlated_subquery",
|
| 88 |
+
family="ecommerce",
|
| 89 |
+
tags=frozenset({"correlated_subquery", "projection_subquery", "ecommerce"}),
|
| 90 |
+
drift_config=None,
|
| 91 |
+
builder=_build,
|
| 92 |
+
base_scale=800,
|
| 93 |
+
)
|
scenarios/scenario_02_select_star_join.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scenario 02 — SELECT * over a join anti-pattern.
|
| 2 |
+
|
| 3 |
+
Baseline pulls every column from three joined tables when only two columns
|
| 4 |
+
are needed. The rewrite projects exactly the requested columns.
|
| 5 |
+
|
| 6 |
+
Schema: products(id, sku, name, description, price_cents, …), orders(id, user_id, created_at),
|
| 7 |
+
order_items(order_id, product_id, qty, unit_price_cents).
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import duckdb
|
| 13 |
+
|
| 14 |
+
from ._fixtures import (
|
| 15 |
+
categorical_choices,
|
| 16 |
+
lognormal_amounts,
|
| 17 |
+
seeded_rng,
|
| 18 |
+
unique_names,
|
| 19 |
+
zipfian_choices,
|
| 20 |
+
)
|
| 21 |
+
from .base import BuilderResult, ScenarioSpec
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
|
| 25 |
+
rng = seeded_rng(spec.scenario_id, seed, scale)
|
| 26 |
+
n_products = max(100, scale // 4)
|
| 27 |
+
n_orders = scale
|
| 28 |
+
n_items = scale * 3
|
| 29 |
+
n_users = max(50, scale // 8)
|
| 30 |
+
|
| 31 |
+
product_ids = list(range(1, n_products + 1))
|
| 32 |
+
skus = unique_names(rng, n_products, prefix="sku")
|
| 33 |
+
names = unique_names(rng, n_products, prefix="p")
|
| 34 |
+
descriptions = [f"Long marketing copy for {n}" * 6 for n in names] # wide col
|
| 35 |
+
prices = [int(x * 100) for x in lognormal_amounts(rng, n_products, mu=3.0, sigma=1.1)]
|
| 36 |
+
categories = categorical_choices(
|
| 37 |
+
rng, ["books", "electronics", "apparel", "grocery", "home"], n_products
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
user_ids = list(range(1, n_users + 1))
|
| 41 |
+
order_user_ids = zipfian_choices(rng, user_ids, n_orders)
|
| 42 |
+
order_created = [1_700_000_000 + rng.randrange(60 * 86_400) for _ in range(n_orders)]
|
| 43 |
+
|
| 44 |
+
item_order_ids = rng.choices(list(range(1, n_orders + 1)), k=n_items)
|
| 45 |
+
item_product_ids = zipfian_choices(rng, product_ids, n_items)
|
| 46 |
+
item_qty = rng.choices([1, 1, 1, 2, 2, 3, 4, 5], k=n_items)
|
| 47 |
+
item_unit_price = [prices[pid - 1] for pid in item_product_ids]
|
| 48 |
+
|
| 49 |
+
conn = duckdb.connect(":memory:")
|
| 50 |
+
conn.execute(
|
| 51 |
+
"CREATE TABLE products("
|
| 52 |
+
" id BIGINT PRIMARY KEY, sku VARCHAR, name VARCHAR, description VARCHAR,"
|
| 53 |
+
" price_cents BIGINT, category VARCHAR);"
|
| 54 |
+
)
|
| 55 |
+
conn.execute(
|
| 56 |
+
"CREATE TABLE orders( id BIGINT PRIMARY KEY, user_id BIGINT, created_at_epoch_s BIGINT);"
|
| 57 |
+
)
|
| 58 |
+
conn.execute(
|
| 59 |
+
"CREATE TABLE order_items("
|
| 60 |
+
" order_id BIGINT, product_id BIGINT, qty INTEGER, unit_price_cents BIGINT);"
|
| 61 |
+
)
|
| 62 |
+
conn.executemany(
|
| 63 |
+
"INSERT INTO products VALUES (?, ?, ?, ?, ?, ?)",
|
| 64 |
+
list(zip(product_ids, skus, names, descriptions, prices, categories, strict=False)),
|
| 65 |
+
)
|
| 66 |
+
conn.executemany(
|
| 67 |
+
"INSERT INTO orders VALUES (?, ?, ?)",
|
| 68 |
+
[
|
| 69 |
+
(oid, uid, ts)
|
| 70 |
+
for oid, (uid, ts) in enumerate(zip(order_user_ids, order_created, strict=False), 1)
|
| 71 |
+
],
|
| 72 |
+
)
|
| 73 |
+
conn.executemany(
|
| 74 |
+
"INSERT INTO order_items VALUES (?, ?, ?, ?)",
|
| 75 |
+
list(zip(item_order_ids, item_product_ids, item_qty, item_unit_price, strict=False)),
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# Anti-pattern: SELECT * inside a subquery that wraps the real join, then
|
| 79 |
+
# the outer query projects only a handful of columns. DuckDB materializes
|
| 80 |
+
# every column of the subquery before the projection can prune it.
|
| 81 |
+
baseline_sql = (
|
| 82 |
+
"SELECT t.order_id, t.name, t.qty "
|
| 83 |
+
"FROM ("
|
| 84 |
+
" SELECT * FROM order_items oi "
|
| 85 |
+
" JOIN products p ON p.id = oi.product_id "
|
| 86 |
+
" JOIN orders o ON o.id = oi.order_id "
|
| 87 |
+
" WHERE p.category = 'books' AND oi.qty >= 2"
|
| 88 |
+
") t "
|
| 89 |
+
"ORDER BY t.order_id, t.name"
|
| 90 |
+
)
|
| 91 |
+
gt_sql_predrift = (
|
| 92 |
+
"SELECT oi.order_id, p.name, oi.qty "
|
| 93 |
+
"FROM order_items oi "
|
| 94 |
+
"JOIN products p ON p.id = oi.product_id "
|
| 95 |
+
"JOIN orders o ON o.id = oi.order_id "
|
| 96 |
+
"WHERE p.category = 'books' AND oi.qty >= 2 "
|
| 97 |
+
"ORDER BY oi.order_id, p.name"
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
synopsis = (
|
| 101 |
+
"products(id PK, sku, name, description, price_cents, category); "
|
| 102 |
+
"orders(id PK, user_id, created_at_epoch_s); "
|
| 103 |
+
"order_items(order_id, product_id, qty, unit_price_cents). "
|
| 104 |
+
"Baseline wraps a three-way join with SELECT * inside a subquery; "
|
| 105 |
+
"only (order_id, product name, qty) are needed downstream."
|
| 106 |
+
)
|
| 107 |
+
return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset()
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
SPEC = ScenarioSpec(
|
| 111 |
+
scenario_id="02_select_star_join",
|
| 112 |
+
family="ecommerce",
|
| 113 |
+
tags=frozenset({"select_star", "over_projection", "join", "ecommerce"}),
|
| 114 |
+
drift_config=None,
|
| 115 |
+
builder=_build,
|
| 116 |
+
# Three-way join with SELECT * is already well-optimized by DuckDB;
|
| 117 |
+
# a larger base keeps first-try baseline above the 1 ms floor.
|
| 118 |
+
base_scale=1_500,
|
| 119 |
+
)
|
scenarios/scenario_03_cartesian_join.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scenario 03 — accidental cartesian join.
|
| 2 |
+
|
| 3 |
+
Baseline joins events × tenants without an equijoin clause and relies on a
|
| 4 |
+
post-filter `WHERE` to restore tenant scoping. The optimizer cannot push
|
| 5 |
+
the filter into the join, so it materializes the full |events|·|tenants|
|
| 6 |
+
cross-product first. The rewrite moves the tenant key into an ON clause.
|
| 7 |
+
|
| 8 |
+
Schema: tenants(id, tier), events(id, tenant_id, kind, severity).
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import duckdb
|
| 14 |
+
|
| 15 |
+
from ._fixtures import categorical_choices, seeded_rng, zipfian_choices
|
| 16 |
+
from .base import BuilderResult, ScenarioSpec
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
|
| 20 |
+
rng = seeded_rng(spec.scenario_id, seed, scale)
|
| 21 |
+
n_tenants = max(20, scale // 40)
|
| 22 |
+
n_events = scale * 4
|
| 23 |
+
|
| 24 |
+
tenant_ids = list(range(1, n_tenants + 1))
|
| 25 |
+
tiers = categorical_choices(
|
| 26 |
+
rng, ["free", "pro", "enterprise"], n_tenants, weights=[0.6, 0.3, 0.1]
|
| 27 |
+
)
|
| 28 |
+
event_tenant_ids = zipfian_choices(rng, tenant_ids, n_events)
|
| 29 |
+
kinds = categorical_choices(
|
| 30 |
+
rng,
|
| 31 |
+
["login", "logout", "action", "error"],
|
| 32 |
+
n_events,
|
| 33 |
+
weights=[0.35, 0.3, 0.3, 0.05],
|
| 34 |
+
)
|
| 35 |
+
severities = categorical_choices(
|
| 36 |
+
rng,
|
| 37 |
+
["info", "warn", "error", "critical"],
|
| 38 |
+
n_events,
|
| 39 |
+
weights=[0.7, 0.2, 0.08, 0.02],
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
conn = duckdb.connect(":memory:")
|
| 43 |
+
conn.execute("CREATE TABLE tenants(id BIGINT PRIMARY KEY, tier VARCHAR);")
|
| 44 |
+
conn.execute(
|
| 45 |
+
"CREATE TABLE events("
|
| 46 |
+
" id BIGINT PRIMARY KEY, tenant_id BIGINT, kind VARCHAR, severity VARCHAR);"
|
| 47 |
+
)
|
| 48 |
+
conn.executemany(
|
| 49 |
+
"INSERT INTO tenants VALUES (?, ?)", list(zip(tenant_ids, tiers, strict=False))
|
| 50 |
+
)
|
| 51 |
+
conn.executemany(
|
| 52 |
+
"INSERT INTO events VALUES (?, ?, ?, ?)",
|
| 53 |
+
[
|
| 54 |
+
(i, tid, k, s)
|
| 55 |
+
for i, (tid, k, s) in enumerate(
|
| 56 |
+
zip(event_tenant_ids, kinds, severities, strict=False), start=1
|
| 57 |
+
)
|
| 58 |
+
],
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
baseline_sql = (
|
| 62 |
+
"SELECT t.tier, COUNT(*) AS n "
|
| 63 |
+
"FROM events e, tenants t "
|
| 64 |
+
"WHERE t.id = e.tenant_id + 0 " # defeat optimizer recognition
|
| 65 |
+
"AND e.severity IN ('error', 'critical') "
|
| 66 |
+
"GROUP BY t.tier ORDER BY t.tier"
|
| 67 |
+
)
|
| 68 |
+
gt_sql_predrift = (
|
| 69 |
+
"SELECT t.tier, COUNT(*) AS n "
|
| 70 |
+
"FROM events e JOIN tenants t ON t.id = e.tenant_id "
|
| 71 |
+
"WHERE e.severity IN ('error', 'critical') "
|
| 72 |
+
"GROUP BY t.tier ORDER BY t.tier"
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
synopsis = (
|
| 76 |
+
"tenants(id PK, tier); events(id PK, tenant_id→tenants.id, kind, severity). "
|
| 77 |
+
"Baseline relies on a WHERE-clause equijoin obscured by arithmetic, "
|
| 78 |
+
"forcing a cartesian materialization."
|
| 79 |
+
)
|
| 80 |
+
return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
SPEC = ScenarioSpec(
|
| 84 |
+
scenario_id="03_cartesian_join",
|
| 85 |
+
family="events",
|
| 86 |
+
tags=frozenset({"cartesian", "missing_join_condition", "events"}),
|
| 87 |
+
drift_config=None,
|
| 88 |
+
builder=_build,
|
| 89 |
+
# Cartesian materialization cost grows with |events|·|tenants| so
|
| 90 |
+
# larger base_scale keeps first-try baseline above the 1 ms floor;
|
| 91 |
+
# the reroll loop doubles from here if needed.
|
| 92 |
+
base_scale=1_200,
|
| 93 |
+
)
|
scenarios/scenario_04_distinct_groupby.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scenario 04 — redundant DISTINCT on top of GROUP BY.
|
| 2 |
+
|
| 3 |
+
Baseline applies DISTINCT after already grouping, forcing a second
|
| 4 |
+
deduplication over a result whose GROUP BY keys are already unique.
|
| 5 |
+
Rewrite drops the DISTINCT entirely.
|
| 6 |
+
|
| 7 |
+
Schema: pageviews(id, session_id, path, ts_epoch_s).
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import duckdb
|
| 13 |
+
|
| 14 |
+
from ._fixtures import categorical_choices, seeded_rng, unique_names, zipfian_choices
|
| 15 |
+
from .base import BuilderResult, ScenarioSpec
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
|
| 19 |
+
rng = seeded_rng(spec.scenario_id, seed, scale)
|
| 20 |
+
n_sessions = max(200, scale // 4)
|
| 21 |
+
n_views = scale * 8
|
| 22 |
+
|
| 23 |
+
session_ids = unique_names(rng, n_sessions, prefix="sess")
|
| 24 |
+
paths = categorical_choices(
|
| 25 |
+
rng,
|
| 26 |
+
[f"/path/{p}" for p in ["home", "about", "product", "cart", "checkout", "help"]],
|
| 27 |
+
n_views,
|
| 28 |
+
)
|
| 29 |
+
view_sessions = zipfian_choices(rng, list(range(n_sessions)), n_views)
|
| 30 |
+
view_session_ids = [session_ids[i] for i in view_sessions]
|
| 31 |
+
view_ts = [1_700_000_000 + rng.randrange(30 * 86_400) for _ in range(n_views)]
|
| 32 |
+
|
| 33 |
+
conn = duckdb.connect(":memory:")
|
| 34 |
+
conn.execute(
|
| 35 |
+
"CREATE TABLE pageviews("
|
| 36 |
+
" id BIGINT PRIMARY KEY, session_id VARCHAR, path VARCHAR, ts_epoch_s BIGINT);"
|
| 37 |
+
)
|
| 38 |
+
conn.executemany(
|
| 39 |
+
"INSERT INTO pageviews VALUES (?, ?, ?, ?)",
|
| 40 |
+
[
|
| 41 |
+
(i, sid, p, ts)
|
| 42 |
+
for i, (sid, p, ts) in enumerate(
|
| 43 |
+
zip(view_session_ids, paths, view_ts, strict=False), start=1
|
| 44 |
+
)
|
| 45 |
+
],
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
baseline_sql = (
|
| 49 |
+
"SELECT DISTINCT session_id, path, COUNT(*) AS hits "
|
| 50 |
+
"FROM pageviews GROUP BY session_id, path "
|
| 51 |
+
"ORDER BY session_id, path"
|
| 52 |
+
)
|
| 53 |
+
gt_sql_predrift = (
|
| 54 |
+
"SELECT session_id, path, COUNT(*) AS hits "
|
| 55 |
+
"FROM pageviews GROUP BY session_id, path "
|
| 56 |
+
"ORDER BY session_id, path"
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
synopsis = (
|
| 60 |
+
"pageviews(id PK, session_id, path, ts_epoch_s). "
|
| 61 |
+
"Baseline applies redundant DISTINCT on a GROUP BY whose keys are already unique."
|
| 62 |
+
)
|
| 63 |
+
return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
SPEC = ScenarioSpec(
|
| 67 |
+
scenario_id="04_distinct_groupby",
|
| 68 |
+
family="saas_logs",
|
| 69 |
+
tags=frozenset({"distinct", "redundant_distinct", "group_by", "saas_logs"}),
|
| 70 |
+
drift_config=None,
|
| 71 |
+
builder=_build,
|
| 72 |
+
base_scale=500,
|
| 73 |
+
)
|
scenarios/scenario_05_nested_subquery.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scenario 05 — deeply nested subquery anti-pattern.
|
| 2 |
+
|
| 3 |
+
Baseline stacks two layers of `IN (SELECT ...)` where a single join + GROUP
|
| 4 |
+
BY expresses the same intent. The rewrite flattens the nesting.
|
| 5 |
+
|
| 6 |
+
Schema: articles(id, author_id, published_at), comments(id, article_id, author_id),
|
| 7 |
+
authors(id, display_name, is_verified).
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import duckdb
|
| 13 |
+
|
| 14 |
+
from ._fixtures import categorical_choices, seeded_rng, unique_names, zipfian_choices
|
| 15 |
+
from .base import BuilderResult, ScenarioSpec
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
|
| 19 |
+
rng = seeded_rng(spec.scenario_id, seed, scale)
|
| 20 |
+
n_authors = max(50, scale // 8)
|
| 21 |
+
n_articles = scale
|
| 22 |
+
n_comments = scale * 4
|
| 23 |
+
|
| 24 |
+
author_ids = list(range(1, n_authors + 1))
|
| 25 |
+
display_names = unique_names(rng, n_authors, prefix="author")
|
| 26 |
+
verified = rng.choices([True, False], weights=[0.2, 0.8], k=n_authors)
|
| 27 |
+
|
| 28 |
+
article_ids = list(range(1, n_articles + 1))
|
| 29 |
+
article_authors = zipfian_choices(rng, author_ids, n_articles)
|
| 30 |
+
article_published = [1_700_000_000 + rng.randrange(120 * 86_400) for _ in range(n_articles)]
|
| 31 |
+
article_statuses = categorical_choices(
|
| 32 |
+
rng, ["draft", "published", "archived"], n_articles, weights=[0.2, 0.7, 0.1]
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
comment_ids = list(range(1, n_comments + 1))
|
| 36 |
+
comment_articles = rng.choices(article_ids, k=n_comments)
|
| 37 |
+
comment_authors = zipfian_choices(rng, author_ids, n_comments)
|
| 38 |
+
|
| 39 |
+
conn = duckdb.connect(":memory:")
|
| 40 |
+
conn.execute(
|
| 41 |
+
"CREATE TABLE authors( id BIGINT PRIMARY KEY, display_name VARCHAR, is_verified BOOLEAN);"
|
| 42 |
+
)
|
| 43 |
+
conn.execute(
|
| 44 |
+
"CREATE TABLE articles("
|
| 45 |
+
" id BIGINT PRIMARY KEY, author_id BIGINT, published_at_epoch_s BIGINT,"
|
| 46 |
+
" status VARCHAR);"
|
| 47 |
+
)
|
| 48 |
+
conn.execute(
|
| 49 |
+
"CREATE TABLE comments( id BIGINT PRIMARY KEY, article_id BIGINT, author_id BIGINT);"
|
| 50 |
+
)
|
| 51 |
+
conn.executemany(
|
| 52 |
+
"INSERT INTO authors VALUES (?, ?, ?)",
|
| 53 |
+
list(zip(author_ids, display_names, verified, strict=False)),
|
| 54 |
+
)
|
| 55 |
+
conn.executemany(
|
| 56 |
+
"INSERT INTO articles VALUES (?, ?, ?, ?)",
|
| 57 |
+
list(zip(article_ids, article_authors, article_published, article_statuses, strict=False)),
|
| 58 |
+
)
|
| 59 |
+
conn.executemany(
|
| 60 |
+
"INSERT INTO comments VALUES (?, ?, ?)",
|
| 61 |
+
list(zip(comment_ids, comment_articles, comment_authors, strict=False)),
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
baseline_sql = (
|
| 65 |
+
"SELECT display_name "
|
| 66 |
+
"FROM authors "
|
| 67 |
+
"WHERE id IN ("
|
| 68 |
+
" SELECT author_id FROM comments "
|
| 69 |
+
" WHERE article_id IN ("
|
| 70 |
+
" SELECT id FROM articles WHERE status = 'published'"
|
| 71 |
+
" )"
|
| 72 |
+
") "
|
| 73 |
+
"ORDER BY display_name"
|
| 74 |
+
)
|
| 75 |
+
gt_sql_predrift = (
|
| 76 |
+
"SELECT DISTINCT a.display_name "
|
| 77 |
+
"FROM authors a "
|
| 78 |
+
"JOIN comments c ON c.author_id = a.id "
|
| 79 |
+
"JOIN articles ar ON ar.id = c.article_id "
|
| 80 |
+
"WHERE ar.status = 'published' "
|
| 81 |
+
"ORDER BY a.display_name"
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
synopsis = (
|
| 85 |
+
"authors(id PK, display_name, is_verified); "
|
| 86 |
+
"articles(id PK, author_id→authors.id, published_at_epoch_s, status); "
|
| 87 |
+
"comments(id PK, article_id→articles.id, author_id→authors.id). "
|
| 88 |
+
"Baseline chains two IN-subqueries where one JOIN + DISTINCT suffices."
|
| 89 |
+
)
|
| 90 |
+
return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset()
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
SPEC = ScenarioSpec(
|
| 94 |
+
scenario_id="05_nested_subquery",
|
| 95 |
+
family="cms",
|
| 96 |
+
tags=frozenset({"nested_subquery", "in_subquery", "cms"}),
|
| 97 |
+
drift_config=None,
|
| 98 |
+
builder=_build,
|
| 99 |
+
base_scale=1_500,
|
| 100 |
+
)
|
scenarios/scenario_06_having_as_where.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scenario 06 — HAVING used as WHERE.
|
| 2 |
+
|
| 3 |
+
Baseline filters on a grouping key inside HAVING, which forces the engine
|
| 4 |
+
to group first and filter after. The rewrite moves the non-aggregate
|
| 5 |
+
predicate into WHERE.
|
| 6 |
+
|
| 7 |
+
Schema: orders(id, tenant_id, user_id, amount_cents, status, created_at_epoch_s).
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import duckdb
|
| 13 |
+
|
| 14 |
+
from ._fixtures import categorical_choices, lognormal_amounts, seeded_rng, zipfian_choices
|
| 15 |
+
from .base import BuilderResult, ScenarioSpec
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
|
| 19 |
+
rng = seeded_rng(spec.scenario_id, seed, scale)
|
| 20 |
+
n_tenants = max(40, scale // 40)
|
| 21 |
+
n_users = max(200, scale // 4)
|
| 22 |
+
n_orders = scale * 10
|
| 23 |
+
|
| 24 |
+
tenant_ids = list(range(1, n_tenants + 1))
|
| 25 |
+
user_ids = list(range(1, n_users + 1))
|
| 26 |
+
user_tenants = rng.choices(tenant_ids, k=n_users)
|
| 27 |
+
|
| 28 |
+
order_user = zipfian_choices(rng, user_ids, n_orders)
|
| 29 |
+
order_tenant = [user_tenants[u - 1] for u in order_user]
|
| 30 |
+
order_amount = [int(x * 100) for x in lognormal_amounts(rng, n_orders, mu=3.5, sigma=0.8)]
|
| 31 |
+
statuses = categorical_choices(
|
| 32 |
+
rng,
|
| 33 |
+
["placed", "fulfilled", "refunded", "cancelled"],
|
| 34 |
+
n_orders,
|
| 35 |
+
weights=[0.55, 0.3, 0.1, 0.05],
|
| 36 |
+
)
|
| 37 |
+
created = [1_700_000_000 + rng.randrange(60 * 86_400) for _ in range(n_orders)]
|
| 38 |
+
|
| 39 |
+
conn = duckdb.connect(":memory:")
|
| 40 |
+
conn.execute(
|
| 41 |
+
"CREATE TABLE orders("
|
| 42 |
+
" id BIGINT PRIMARY KEY, tenant_id BIGINT, user_id BIGINT,"
|
| 43 |
+
" amount_cents BIGINT, status VARCHAR, created_at_epoch_s BIGINT);"
|
| 44 |
+
)
|
| 45 |
+
conn.executemany(
|
| 46 |
+
"INSERT INTO orders VALUES (?, ?, ?, ?, ?, ?)",
|
| 47 |
+
[
|
| 48 |
+
(i, t, u, a, s, ts)
|
| 49 |
+
for i, (t, u, a, s, ts) in enumerate(
|
| 50 |
+
zip(order_tenant, order_user, order_amount, statuses, created, strict=False),
|
| 51 |
+
start=1,
|
| 52 |
+
)
|
| 53 |
+
],
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
# Anti-pattern: filter on a row-level column (status) that IS in GROUP BY
|
| 57 |
+
# inside HAVING, forcing the engine to build groups for every status value
|
| 58 |
+
# before discarding most of them. The rewrite moves the filter into WHERE
|
| 59 |
+
# so aggregation only runs over rows we actually care about.
|
| 60 |
+
baseline_sql = (
|
| 61 |
+
"SELECT tenant_id, status, SUM(amount_cents) AS total_cents "
|
| 62 |
+
"FROM orders "
|
| 63 |
+
"GROUP BY tenant_id, status "
|
| 64 |
+
"HAVING status = 'fulfilled' "
|
| 65 |
+
" AND SUM(amount_cents) >= 100000 "
|
| 66 |
+
"ORDER BY tenant_id"
|
| 67 |
+
)
|
| 68 |
+
gt_sql_predrift = (
|
| 69 |
+
"SELECT tenant_id, status, SUM(amount_cents) AS total_cents "
|
| 70 |
+
"FROM orders "
|
| 71 |
+
"WHERE status = 'fulfilled' "
|
| 72 |
+
"GROUP BY tenant_id, status "
|
| 73 |
+
"HAVING SUM(amount_cents) >= 100000 "
|
| 74 |
+
"ORDER BY tenant_id"
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
synopsis = (
|
| 78 |
+
"orders(id PK, tenant_id, user_id, amount_cents, status, created_at_epoch_s). "
|
| 79 |
+
"Baseline filters `status` inside HAVING, forcing aggregation over "
|
| 80 |
+
"every status group before discarding all but 'fulfilled'."
|
| 81 |
+
)
|
| 82 |
+
return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset()
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
SPEC = ScenarioSpec(
|
| 86 |
+
scenario_id="06_having_as_where",
|
| 87 |
+
family="ecommerce",
|
| 88 |
+
tags=frozenset({"having_as_where", "aggregate_filter", "ecommerce"}),
|
| 89 |
+
drift_config=None,
|
| 90 |
+
builder=_build,
|
| 91 |
+
base_scale=1_500,
|
| 92 |
+
)
|
scenarios/scenario_07_drift_column_rename.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scenario 07 — drift: column rename (orders.user_id → orders.account_id).
|
| 2 |
+
|
| 3 |
+
Baseline groups orders by `user_id`. When the drift fires, the column is
|
| 4 |
+
renamed; the agent must rewrite its query against `account_id`. Row
|
| 5 |
+
semantics are unchanged — the post-drift hash equals the pre-drift hash
|
| 6 |
+
because the only thing that changed is the column label.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import duckdb
|
| 12 |
+
|
| 13 |
+
from ._fixtures import lognormal_amounts, seeded_rng, zipfian_choices
|
| 14 |
+
from .base import BuilderResult, DriftConfig, ScenarioSpec
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
|
| 18 |
+
rng = seeded_rng(spec.scenario_id, seed, scale)
|
| 19 |
+
n_users = max(200, scale // 2)
|
| 20 |
+
n_orders = scale * 4
|
| 21 |
+
|
| 22 |
+
user_ids = list(range(1, n_users + 1))
|
| 23 |
+
order_users = zipfian_choices(rng, user_ids, n_orders)
|
| 24 |
+
amounts = lognormal_amounts(rng, n_orders, mu=3.0, sigma=0.8)
|
| 25 |
+
|
| 26 |
+
conn = duckdb.connect(":memory:")
|
| 27 |
+
conn.execute("CREATE TABLE users(id BIGINT PRIMARY KEY, signup_month INTEGER);")
|
| 28 |
+
conn.execute("CREATE TABLE orders(id BIGINT PRIMARY KEY, user_id BIGINT, amount DOUBLE);")
|
| 29 |
+
conn.executemany(
|
| 30 |
+
"INSERT INTO users VALUES (?, ?)",
|
| 31 |
+
[(uid, (uid % 12) + 1) for uid in user_ids],
|
| 32 |
+
)
|
| 33 |
+
conn.executemany(
|
| 34 |
+
"INSERT INTO orders VALUES (?, ?, ?)",
|
| 35 |
+
[(i, uid, amt) for i, (uid, amt) in enumerate(zip(order_users, amounts, strict=False), 1)],
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
baseline_sql = (
|
| 39 |
+
"SELECT user_id, COUNT(*) AS n_orders, ROUND(SUM(amount), 2) AS total "
|
| 40 |
+
"FROM orders GROUP BY user_id ORDER BY user_id"
|
| 41 |
+
)
|
| 42 |
+
gt_sql_predrift = baseline_sql # static part — baseline IS correct pre-drift
|
| 43 |
+
gt_sql_postdrift = (
|
| 44 |
+
"SELECT account_id, COUNT(*) AS n_orders, ROUND(SUM(amount), 2) AS total "
|
| 45 |
+
"FROM orders GROUP BY account_id ORDER BY account_id"
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
synopsis = (
|
| 49 |
+
"users(id PK, signup_month); orders(id PK, user_id→users.id, amount). "
|
| 50 |
+
"Under drift, orders.user_id is renamed to orders.account_id."
|
| 51 |
+
)
|
| 52 |
+
return (
|
| 53 |
+
conn,
|
| 54 |
+
baseline_sql,
|
| 55 |
+
gt_sql_predrift,
|
| 56 |
+
gt_sql_postdrift,
|
| 57 |
+
synopsis,
|
| 58 |
+
frozenset({"account_id"}),
|
| 59 |
+
frozenset({"user_id"}),
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
SPEC = ScenarioSpec(
|
| 64 |
+
scenario_id="07_drift_column_rename",
|
| 65 |
+
family="ecommerce",
|
| 66 |
+
tags=frozenset({"drift", "column_rename", "ecommerce"}),
|
| 67 |
+
drift_config=DriftConfig(
|
| 68 |
+
kind="column_rename",
|
| 69 |
+
payload={"table": "orders", "old": "user_id", "new": "account_id"},
|
| 70 |
+
),
|
| 71 |
+
builder=_build,
|
| 72 |
+
base_scale=2_000,
|
| 73 |
+
)
|
scenarios/scenario_08_drift_date_format.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scenario 08 — drift: date format (events.ts iso_string → epoch_ms).
|
| 2 |
+
|
| 3 |
+
Baseline filters events inside a specific UTC day via ISO string
|
| 4 |
+
comparisons. When the drift fires, the `ts` column becomes a BIGINT of
|
| 5 |
+
epoch-ms; the agent must rewrite comparisons against the numeric value.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from datetime import UTC, datetime, timedelta
|
| 11 |
+
|
| 12 |
+
import duckdb
|
| 13 |
+
|
| 14 |
+
from ._fixtures import categorical_choices, seeded_rng
|
| 15 |
+
from .base import BuilderResult, DriftConfig, ScenarioSpec
|
| 16 |
+
|
| 17 |
+
# Fixed anchor day — comparisons are deterministic across seeds.
|
| 18 |
+
_ANCHOR = datetime(2026, 4, 21, tzinfo=UTC)
|
| 19 |
+
_ANCHOR_NEXT = _ANCHOR + timedelta(days=1)
|
| 20 |
+
_ANCHOR_ISO = _ANCHOR.isoformat().replace("+00:00", "Z")
|
| 21 |
+
_ANCHOR_NEXT_ISO = _ANCHOR_NEXT.isoformat().replace("+00:00", "Z")
|
| 22 |
+
_ANCHOR_MS = int(_ANCHOR.timestamp() * 1000)
|
| 23 |
+
_ANCHOR_NEXT_MS = int(_ANCHOR_NEXT.timestamp() * 1000)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
|
| 27 |
+
rng = seeded_rng(spec.scenario_id, seed, scale)
|
| 28 |
+
n_events = scale * 8
|
| 29 |
+
|
| 30 |
+
window_start = _ANCHOR - timedelta(days=3)
|
| 31 |
+
window_span_s = 7 * 86_400
|
| 32 |
+
event_dts = [
|
| 33 |
+
(window_start + timedelta(seconds=rng.randrange(window_span_s))) for _ in range(n_events)
|
| 34 |
+
]
|
| 35 |
+
event_iso = [dt.isoformat().replace("+00:00", "Z") for dt in event_dts]
|
| 36 |
+
kinds = categorical_choices(
|
| 37 |
+
rng, ["login", "action", "error"], n_events, weights=[0.6, 0.35, 0.05]
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
conn = duckdb.connect(":memory:")
|
| 41 |
+
conn.execute("CREATE TABLE events( id BIGINT PRIMARY KEY, ts VARCHAR, kind VARCHAR);")
|
| 42 |
+
conn.executemany(
|
| 43 |
+
"INSERT INTO events VALUES (?, ?, ?)",
|
| 44 |
+
[(i, iso, k) for i, (iso, k) in enumerate(zip(event_iso, kinds, strict=False), 1)],
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
baseline_sql = (
|
| 48 |
+
f"SELECT kind, COUNT(*) AS n FROM events "
|
| 49 |
+
f"WHERE ts >= '{_ANCHOR_ISO}' AND ts < '{_ANCHOR_NEXT_ISO}' "
|
| 50 |
+
f"GROUP BY kind ORDER BY kind"
|
| 51 |
+
)
|
| 52 |
+
gt_sql_predrift = baseline_sql
|
| 53 |
+
gt_sql_postdrift = (
|
| 54 |
+
f"SELECT kind, COUNT(*) AS n FROM events "
|
| 55 |
+
f"WHERE ts >= {_ANCHOR_MS} AND ts < {_ANCHOR_NEXT_MS} "
|
| 56 |
+
f"GROUP BY kind ORDER BY kind"
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
synopsis = (
|
| 60 |
+
"events(id PK, ts VARCHAR(ISO-8601 UTC), kind). Under drift, `ts` "
|
| 61 |
+
f"becomes BIGINT epoch-ms. Filter window is {_ANCHOR_ISO} – "
|
| 62 |
+
f"{_ANCHOR_NEXT_ISO} (i.e. epoch-ms in "
|
| 63 |
+
f"[{_ANCHOR_MS}, {_ANCHOR_NEXT_MS}))."
|
| 64 |
+
)
|
| 65 |
+
# Date-format drift keeps the ``ts`` identifier; what changes is the
|
| 66 |
+
# literal shape (ISO string → epoch-ms integer). The rubric
|
| 67 |
+
# therefore can't distinguish "adapted" from "not adapted" on
|
| 68 |
+
# identifiers alone, so we expose the ISO anchor strings as the
|
| 69 |
+
# pre-drift distinctive set and leave postdrift empty — the rubric
|
| 70 |
+
# treats absence-of-predrift-markers as adaptation whenever
|
| 71 |
+
# ``postdrift_identifiers`` is empty (DriftAdapt rubric case).
|
| 72 |
+
return (
|
| 73 |
+
conn,
|
| 74 |
+
baseline_sql,
|
| 75 |
+
gt_sql_predrift,
|
| 76 |
+
gt_sql_postdrift,
|
| 77 |
+
synopsis,
|
| 78 |
+
frozenset(),
|
| 79 |
+
frozenset({_ANCHOR_ISO, _ANCHOR_NEXT_ISO}),
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
SPEC = ScenarioSpec(
|
| 84 |
+
scenario_id="08_drift_date_format",
|
| 85 |
+
family="events",
|
| 86 |
+
tags=frozenset({"drift", "date_format", "iso_to_epoch", "events"}),
|
| 87 |
+
drift_config=DriftConfig(
|
| 88 |
+
kind="date_format",
|
| 89 |
+
payload={"table": "events", "col": "ts", "from": "iso_string", "to": "epoch_ms"},
|
| 90 |
+
),
|
| 91 |
+
builder=_build,
|
| 92 |
+
base_scale=500,
|
| 93 |
+
)
|
scenarios/scenario_09_drift_enum_rule.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scenario 09 — drift: enum split ('active' → 'ACTIVE' / 'ACTIVE_V2').
|
| 2 |
+
|
| 3 |
+
Baseline counts tenants whose status is 'active'. After drift, every
|
| 4 |
+
previously-'active' row has been relabelled to either 'ACTIVE' or
|
| 5 |
+
'ACTIVE_V2' (deterministic round-robin). The agent must filter on the
|
| 6 |
+
union of the new values to recover the business-equivalent count.
|
| 7 |
+
|
| 8 |
+
Note: unlike 07 and 10, the post-drift data changed, so the post-drift
|
| 9 |
+
ground-truth hash is computed against the post-drift rows — the agent's
|
| 10 |
+
result set now reflects the new status values.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import duckdb
|
| 16 |
+
|
| 17 |
+
from ._fixtures import categorical_choices, seeded_rng
|
| 18 |
+
from .base import BuilderResult, DriftConfig, ScenarioSpec
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
|
| 22 |
+
rng = seeded_rng(spec.scenario_id, seed, scale)
|
| 23 |
+
n_tenants = max(400, scale)
|
| 24 |
+
|
| 25 |
+
statuses = categorical_choices(
|
| 26 |
+
rng,
|
| 27 |
+
["active", "trial", "suspended", "churned"],
|
| 28 |
+
n_tenants,
|
| 29 |
+
weights=[0.55, 0.2, 0.15, 0.1],
|
| 30 |
+
)
|
| 31 |
+
tiers = categorical_choices(
|
| 32 |
+
rng, ["free", "pro", "business"], n_tenants, weights=[0.6, 0.3, 0.1]
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
conn = duckdb.connect(":memory:")
|
| 36 |
+
conn.execute("CREATE TABLE tenants( id BIGINT PRIMARY KEY, status VARCHAR, tier VARCHAR);")
|
| 37 |
+
conn.executemany(
|
| 38 |
+
"INSERT INTO tenants VALUES (?, ?, ?)",
|
| 39 |
+
[(i, s, t) for i, (s, t) in enumerate(zip(statuses, tiers, strict=False), 1)],
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Pre-drift: all 'active' tenants; post-drift: union of the new labels.
|
| 43 |
+
baseline_sql = (
|
| 44 |
+
"SELECT tier, COUNT(*) AS n "
|
| 45 |
+
"FROM tenants WHERE status = 'active' "
|
| 46 |
+
"GROUP BY tier ORDER BY tier"
|
| 47 |
+
)
|
| 48 |
+
gt_sql_predrift = baseline_sql
|
| 49 |
+
gt_sql_postdrift = (
|
| 50 |
+
"SELECT tier, COUNT(*) AS n "
|
| 51 |
+
"FROM tenants WHERE status IN ('ACTIVE', 'ACTIVE_V2') "
|
| 52 |
+
"GROUP BY tier ORDER BY tier"
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
synopsis = (
|
| 56 |
+
"tenants(id PK, status, tier). Under drift, status='active' is split "
|
| 57 |
+
"into 'ACTIVE' and 'ACTIVE_V2'; 'trial'/'suspended'/'churned' are unchanged."
|
| 58 |
+
)
|
| 59 |
+
return (
|
| 60 |
+
conn,
|
| 61 |
+
baseline_sql,
|
| 62 |
+
gt_sql_predrift,
|
| 63 |
+
gt_sql_postdrift,
|
| 64 |
+
synopsis,
|
| 65 |
+
frozenset({"ACTIVE", "ACTIVE_V2"}),
|
| 66 |
+
frozenset({"active"}),
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
SPEC = ScenarioSpec(
|
| 71 |
+
scenario_id="09_drift_enum_rule",
|
| 72 |
+
family="multitenant",
|
| 73 |
+
tags=frozenset({"drift", "enum_rule", "business_rule", "multitenant"}),
|
| 74 |
+
drift_config=DriftConfig(
|
| 75 |
+
kind="enum_rule",
|
| 76 |
+
payload={
|
| 77 |
+
"table": "tenants",
|
| 78 |
+
"col": "status",
|
| 79 |
+
"old_value": "active",
|
| 80 |
+
"new_values": ["ACTIVE", "ACTIVE_V2"],
|
| 81 |
+
},
|
| 82 |
+
),
|
| 83 |
+
builder=_build,
|
| 84 |
+
base_scale=600,
|
| 85 |
+
)
|
scenarios/scenario_10_drift_field_deprecation.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scenario 10 — drift: inline string col replaced by FK lookup.
|
| 2 |
+
|
| 3 |
+
Baseline groups posts by the inline `author_name` string. Under drift, a
|
| 4 |
+
`users(id, full_name)` lookup is created, `posts.author_name` is dropped,
|
| 5 |
+
and `posts.users_id` is added with a backfilled FK. The correct rewrite
|
| 6 |
+
joins through `users` and groups by `full_name`.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import duckdb
|
| 12 |
+
|
| 13 |
+
from ._fixtures import seeded_rng, unique_names, zipfian_choices
|
| 14 |
+
from .base import BuilderResult, DriftConfig, ScenarioSpec
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
|
| 18 |
+
rng = seeded_rng(spec.scenario_id, seed, scale)
|
| 19 |
+
n_authors = max(40, scale // 10)
|
| 20 |
+
n_posts = scale * 4
|
| 21 |
+
|
| 22 |
+
author_names = unique_names(rng, n_authors, prefix="author")
|
| 23 |
+
post_author_idx = zipfian_choices(rng, list(range(n_authors)), n_posts)
|
| 24 |
+
post_author_names = [author_names[i] for i in post_author_idx]
|
| 25 |
+
|
| 26 |
+
conn = duckdb.connect(":memory:")
|
| 27 |
+
conn.execute("CREATE TABLE posts( id BIGINT PRIMARY KEY, author_name VARCHAR, body VARCHAR);")
|
| 28 |
+
conn.executemany(
|
| 29 |
+
"INSERT INTO posts VALUES (?, ?, ?)",
|
| 30 |
+
[(i, n, f"body {i}") for i, n in enumerate(post_author_names, 1)],
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
baseline_sql = (
|
| 34 |
+
"SELECT author_name, COUNT(*) AS n_posts "
|
| 35 |
+
"FROM posts GROUP BY author_name "
|
| 36 |
+
"ORDER BY author_name"
|
| 37 |
+
)
|
| 38 |
+
gt_sql_predrift = baseline_sql
|
| 39 |
+
gt_sql_postdrift = (
|
| 40 |
+
"SELECT u.full_name AS author_name, COUNT(*) AS n_posts "
|
| 41 |
+
"FROM posts p JOIN users u ON u.id = p.users_id "
|
| 42 |
+
"GROUP BY u.full_name ORDER BY u.full_name"
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
synopsis = (
|
| 46 |
+
"posts(id PK, author_name, body). Under drift, posts.author_name is "
|
| 47 |
+
"deprecated; a new users(id PK, full_name) table is created and "
|
| 48 |
+
"posts gains a users_id FK. Rewrites must JOIN through users."
|
| 49 |
+
)
|
| 50 |
+
return (
|
| 51 |
+
conn,
|
| 52 |
+
baseline_sql,
|
| 53 |
+
gt_sql_predrift,
|
| 54 |
+
gt_sql_postdrift,
|
| 55 |
+
synopsis,
|
| 56 |
+
frozenset({"users", "users_id", "full_name"}),
|
| 57 |
+
frozenset({"author_name"}),
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
SPEC = ScenarioSpec(
|
| 62 |
+
scenario_id="10_drift_field_deprecation",
|
| 63 |
+
family="cms",
|
| 64 |
+
tags=frozenset({"drift", "field_deprecation", "fk_backfill", "cms"}),
|
| 65 |
+
drift_config=DriftConfig(
|
| 66 |
+
kind="field_deprecation",
|
| 67 |
+
payload={
|
| 68 |
+
"orig": ("posts", "author_name"),
|
| 69 |
+
"lookup": ("users", "id", "full_name"),
|
| 70 |
+
},
|
| 71 |
+
),
|
| 72 |
+
builder=_build,
|
| 73 |
+
base_scale=1_500,
|
| 74 |
+
)
|
server/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SQLDrift server package."""
|
| 2 |
+
|
| 3 |
+
from server.sql_drift_env_environment import SqlDriftEnvironment
|
| 4 |
+
|
| 5 |
+
__all__ = ["SqlDriftEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI factory for the SQLDrift environment.
|
| 2 |
+
|
| 3 |
+
``create_app()`` returns a fully-wired FastAPI app exposing the
|
| 4 |
+
stateless HTTP routes (``/health``, ``/schema``, ``/reset``, ``/step``)
|
| 5 |
+
and the stateful ``/ws`` WebSocket session. Stateful multi-step
|
| 6 |
+
episodes must go through ``/ws``; each HTTP ``/step`` spawns a
|
| 7 |
+
fresh env instance that is ``close()``-d in ``finally`` (one env per request).
|
| 8 |
+
|
| 9 |
+
``main()`` runs the server with Uvicorn — exported as the
|
| 10 |
+
``[project.scripts] sql-drift-server`` entry point.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
from typing import Any
|
| 16 |
+
from uuid import uuid4
|
| 17 |
+
|
| 18 |
+
from openenv.core.env_server.http_server import create_app as _openenv_create_app
|
| 19 |
+
|
| 20 |
+
from models import SqlDriftAction, SqlDriftObservation
|
| 21 |
+
from skill_library import DEFAULT_STORE_DIR, Store, cleanup_stale_session_dirs
|
| 22 |
+
|
| 23 |
+
from . import settings
|
| 24 |
+
from .sql_drift_env_environment import SqlDriftEnvironment
|
| 25 |
+
|
| 26 |
+
ENV_NAME = "sql_drift_env"
|
| 27 |
+
DEFAULT_MAX_CONCURRENT_ENVS = settings.MAX_CONCURRENT_ENVS
|
| 28 |
+
_SESSION_STORE_ROOT = DEFAULT_STORE_DIR / "sessions"
|
| 29 |
+
|
| 30 |
+
# Purge stale session directories left by previous server runs before
|
| 31 |
+
# accepting any traffic. Failures are non-fatal.
|
| 32 |
+
_startup_removed = cleanup_stale_session_dirs(
|
| 33 |
+
_SESSION_STORE_ROOT, settings.SKILL_STORE_SESSION_TTL_HOURS
|
| 34 |
+
)
|
| 35 |
+
if _startup_removed:
|
| 36 |
+
import logging as _logging
|
| 37 |
+
|
| 38 |
+
_logging.getLogger("sql_drift_env.app.server.app").info(
|
| 39 |
+
"startup: removed %d stale session skill-store dirs from %s",
|
| 40 |
+
_startup_removed,
|
| 41 |
+
_SESSION_STORE_ROOT,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _create_server_environment() -> SqlDriftEnvironment:
|
| 46 |
+
"""Build one server-managed env with its own on-disk skill library.
|
| 47 |
+
|
| 48 |
+
``cleanup_on_close=True`` ensures the session directory is deleted when
|
| 49 |
+
the WebSocket session ends, preventing unbounded on-disk session growth.
|
| 50 |
+
"""
|
| 51 |
+
session_dir = _SESSION_STORE_ROOT / uuid4().hex
|
| 52 |
+
return SqlDriftEnvironment(
|
| 53 |
+
skill_store=Store(directory=session_dir),
|
| 54 |
+
cleanup_on_close=True,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def create_app(max_concurrent_envs: int | None = None) -> Any:
|
| 59 |
+
"""Build the FastAPI app bound to a fresh-env factory per session."""
|
| 60 |
+
if max_concurrent_envs is None:
|
| 61 |
+
max_concurrent_envs = DEFAULT_MAX_CONCURRENT_ENVS
|
| 62 |
+
return _openenv_create_app(
|
| 63 |
+
env=_create_server_environment,
|
| 64 |
+
action_cls=SqlDriftAction,
|
| 65 |
+
observation_cls=SqlDriftObservation,
|
| 66 |
+
env_name=ENV_NAME,
|
| 67 |
+
max_concurrent_envs=max_concurrent_envs,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def main(host: str = settings.SERVER_HOST, port: int = settings.SERVER_PORT) -> None:
|
| 72 |
+
"""Uvicorn entry point — matches the [project.scripts] wiring."""
|
| 73 |
+
import uvicorn
|
| 74 |
+
|
| 75 |
+
uvicorn.run(create_app(), host=host, port=port)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Module-level app instance for uvicorn's ``module:attr`` syntax
|
| 79 |
+
# (``uvicorn server.app:app``) and the ``openenv.yaml`` ``app:`` field.
|
| 80 |
+
# Built at import time; safe because the OpenEnv factory only stores the
|
| 81 |
+
# environment factory and instantiates per request / session.
|
| 82 |
+
app = create_app()
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
__all__ = ["ENV_NAME", "app", "create_app", "main"]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
if __name__ == "__main__":
|
| 89 |
+
main()
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file was autogenerated by uv via the following command:
|
| 2 |
+
# uv export --format requirements-txt --no-dev --no-hashes --no-emit-project --no-annotate -o server/requirements.txt
|
| 3 |
+
aiofile==3.9.0
|
| 4 |
+
annotated-doc==0.0.4
|
| 5 |
+
annotated-types==0.7.0
|
| 6 |
+
anyio==4.13.0
|
| 7 |
+
attrs==26.1.0
|
| 8 |
+
audioop-lts==0.2.2 ; python_full_version >= '3.13'
|
| 9 |
+
authlib==1.7.0
|
| 10 |
+
beartype==0.22.9
|
| 11 |
+
brotli==1.2.0
|
| 12 |
+
cachetools==7.0.6
|
| 13 |
+
caio==0.9.25
|
| 14 |
+
certifi==2026.2.25
|
| 15 |
+
cffi==2.0.0 ; platform_python_implementation != 'PyPy'
|
| 16 |
+
charset-normalizer==3.4.7
|
| 17 |
+
click==8.3.2
|
| 18 |
+
colorama==0.4.6 ; sys_platform == 'win32'
|
| 19 |
+
cryptography==46.0.7
|
| 20 |
+
cyclopts==4.10.2
|
| 21 |
+
distro==1.9.0
|
| 22 |
+
dnspython==2.8.0
|
| 23 |
+
docstring-parser==0.18.0
|
| 24 |
+
docutils==0.22.4
|
| 25 |
+
duckdb==1.5.2
|
| 26 |
+
email-validator==2.3.0
|
| 27 |
+
exceptiongroup==1.3.1
|
| 28 |
+
fastapi==0.136.0
|
| 29 |
+
fastmcp==3.2.4
|
| 30 |
+
filelock==3.29.0
|
| 31 |
+
fsspec==2026.2.0
|
| 32 |
+
gradio==6.13.0
|
| 33 |
+
gradio-client==2.5.0
|
| 34 |
+
griffelib==2.0.2
|
| 35 |
+
groovy==0.1.2
|
| 36 |
+
h11==0.16.0
|
| 37 |
+
hf-gradio==0.4.1
|
| 38 |
+
hf-xet==1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
| 39 |
+
httpcore==1.0.9
|
| 40 |
+
httpx==0.28.1
|
| 41 |
+
httpx-sse==0.4.3
|
| 42 |
+
huggingface-hub==1.11.0
|
| 43 |
+
idna==3.11
|
| 44 |
+
importlib-metadata==8.7.1
|
| 45 |
+
jaraco-classes==3.4.0
|
| 46 |
+
jaraco-context==6.1.2
|
| 47 |
+
jaraco-functools==4.4.0
|
| 48 |
+
jeepney==0.9.0 ; sys_platform == 'linux'
|
| 49 |
+
jinja2==3.1.6
|
| 50 |
+
jiter==0.14.0
|
| 51 |
+
joserfc==1.6.4
|
| 52 |
+
jsonref==1.1.0
|
| 53 |
+
jsonschema==4.26.0
|
| 54 |
+
jsonschema-path==0.4.5
|
| 55 |
+
jsonschema-specifications==2025.9.1
|
| 56 |
+
keyring==25.7.0
|
| 57 |
+
markdown-it-py==4.0.0
|
| 58 |
+
markupsafe==3.0.3
|
| 59 |
+
mcp==1.27.0
|
| 60 |
+
mdurl==0.1.2
|
| 61 |
+
more-itertools==11.0.2
|
| 62 |
+
numpy==2.4.4
|
| 63 |
+
openai==2.32.0
|
| 64 |
+
openapi-pydantic==0.5.1
|
| 65 |
+
openenv-core==0.2.3
|
| 66 |
+
opentelemetry-api==1.41.0
|
| 67 |
+
orjson==3.11.8
|
| 68 |
+
packaging==26.1
|
| 69 |
+
pandas==3.0.2
|
| 70 |
+
pathable==0.5.0
|
| 71 |
+
pillow==12.2.0
|
| 72 |
+
platformdirs==4.9.6
|
| 73 |
+
py-key-value-aio==0.4.4
|
| 74 |
+
pycparser==3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy'
|
| 75 |
+
pydantic==2.13.3
|
| 76 |
+
pydantic-core==2.46.3
|
| 77 |
+
pydantic-settings==2.14.0
|
| 78 |
+
pydub==0.25.1
|
| 79 |
+
pygments==2.20.0
|
| 80 |
+
pyjwt==2.12.1
|
| 81 |
+
pyperclip==1.11.0
|
| 82 |
+
python-dateutil==2.9.0.post0
|
| 83 |
+
python-dotenv==1.2.2
|
| 84 |
+
python-multipart==0.0.26
|
| 85 |
+
pytz==2026.1.post1
|
| 86 |
+
pywin32==311 ; sys_platform == 'win32'
|
| 87 |
+
pywin32-ctypes==0.2.3 ; sys_platform == 'win32'
|
| 88 |
+
pyyaml==6.0.3
|
| 89 |
+
referencing==0.37.0
|
| 90 |
+
requests==2.33.1
|
| 91 |
+
rich==15.0.0
|
| 92 |
+
rich-rst==1.3.2
|
| 93 |
+
rpds-py==0.30.0
|
| 94 |
+
safehttpx==0.1.7
|
| 95 |
+
secretstorage==3.5.0 ; sys_platform == 'linux'
|
| 96 |
+
semantic-version==2.10.0
|
| 97 |
+
shellingham==1.5.4
|
| 98 |
+
six==1.17.0
|
| 99 |
+
sniffio==1.3.1
|
| 100 |
+
sqlglot==30.6.0
|
| 101 |
+
sse-starlette==3.3.4
|
| 102 |
+
starlette==1.0.0
|
| 103 |
+
tomli==2.4.1
|
| 104 |
+
tomli-w==1.2.0
|
| 105 |
+
tomlkit==0.14.0
|
| 106 |
+
tqdm==4.67.3
|
| 107 |
+
typer==0.24.1
|
| 108 |
+
typing-extensions==4.15.0
|
| 109 |
+
typing-inspection==0.4.2
|
| 110 |
+
tzdata==2026.1 ; sys_platform == 'emscripten' or sys_platform == 'win32'
|
| 111 |
+
uncalled-for==0.3.1
|
| 112 |
+
urllib3==2.6.3
|
| 113 |
+
uvicorn==0.44.0
|
| 114 |
+
watchfiles==1.1.1
|
| 115 |
+
websockets==16.0
|
| 116 |
+
zipp==3.23.1
|
server/settings.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Server/runtime settings loaded from the repo-local `.env`."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from utilities.env_loader import env_float, env_int, env_str
|
| 6 |
+
|
| 7 |
+
SERVER_HOST = env_str("SQL_DRIFT_SERVER_HOST", "0.0.0.0")
|
| 8 |
+
SERVER_PORT = env_int("SQL_DRIFT_SERVER_PORT", 8000, min_value=1, max_value=65535)
|
| 9 |
+
MAX_CONCURRENT_ENVS = env_int("SQL_DRIFT_MAX_CONCURRENT_ENVS", 4, min_value=1)
|
| 10 |
+
|
| 11 |
+
DEFAULT_STEP_BUDGET = env_int("SQL_DRIFT_DEFAULT_STEP_BUDGET", 25, min_value=1)
|
| 12 |
+
MAX_RESULT_ROWS = env_int("SQL_DRIFT_MAX_RESULT_ROWS", 1_000, min_value=1)
|
| 13 |
+
QUERY_TIMEOUT_S = env_float("SQL_DRIFT_QUERY_TIMEOUT_S", 2.0, min_value=0.001)
|
| 14 |
+
|
| 15 |
+
# Session skill-store directories older than this many hours are removed at
|
| 16 |
+
# startup and when the owning environment is closed. Set to 0 to disable
|
| 17 |
+
# TTL-based cleanup (directories will still be removed on close when
|
| 18 |
+
# cleanup_on_close=True is set for a server-managed environment).
|
| 19 |
+
SKILL_STORE_SESSION_TTL_HOURS = env_float(
|
| 20 |
+
"SQL_DRIFT_SKILL_STORE_SESSION_TTL_HOURS", 24.0, min_value=0.0
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
__all__ = [
|
| 24 |
+
"DEFAULT_STEP_BUDGET",
|
| 25 |
+
"MAX_CONCURRENT_ENVS",
|
| 26 |
+
"MAX_RESULT_ROWS",
|
| 27 |
+
"QUERY_TIMEOUT_S",
|
| 28 |
+
"SERVER_HOST",
|
| 29 |
+
"SERVER_PORT",
|
| 30 |
+
"SKILL_STORE_SESSION_TTL_HOURS",
|
| 31 |
+
]
|
server/sql_drift_env_environment.py
ADDED
|
@@ -0,0 +1,899 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenEnv ``Environment`` implementation for SQLDrift.
|
| 2 |
+
|
| 3 |
+
Responsibilities:
|
| 4 |
+
|
| 5 |
+
* Own the private :class:`engine.runtime.RuntimeEpisodeState` and the
|
| 6 |
+
composite :class:`engine.reward.SqlDriftRubric` for the current episode.
|
| 7 |
+
* Dispatch each of the eight tool-call payloads to a dedicated
|
| 8 |
+
``_handle_<tool>`` method that returns a typed
|
| 9 |
+
:class:`models.ToolResult` (or :class:`models.ToolError`).
|
| 10 |
+
* Fire drift on a schedule blended with a cooldown: ``max(scheduled,
|
| 11 |
+
first_run_query_step + cooldown)`` before the agent acts on the step
|
| 12 |
+
where drift applies, then recompute the post-drift ground truth hash.
|
| 13 |
+
* Publish public observations (:class:`models.SqlDriftObservation`) and a
|
| 14 |
+
strictly sanitised public state snapshot (:class:`models.SqlDriftState`).
|
| 15 |
+
|
| 16 |
+
Privacy: ``self._runtime`` holds the DuckDB handle, ground-truth hashes,
|
| 17 |
+
baseline runtime, and seed. They stay inside this class; the rubric reads
|
| 18 |
+
them via a closure, and ``env.state`` exposes only a fixed whitelist of fields.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import contextlib
|
| 24 |
+
import math
|
| 25 |
+
import re
|
| 26 |
+
import secrets
|
| 27 |
+
from random import Random
|
| 28 |
+
from typing import TYPE_CHECKING, Any, Literal
|
| 29 |
+
|
| 30 |
+
import duckdb
|
| 31 |
+
import sqlglot
|
| 32 |
+
from openenv.core.env_server.interfaces import Environment
|
| 33 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 34 |
+
|
| 35 |
+
from actors import dba_oracle
|
| 36 |
+
from actors.engineering_manager import author_changelog
|
| 37 |
+
from engine.drift import apply_drift
|
| 38 |
+
from engine.profiler import (
|
| 39 |
+
QueryWatchdogEscalationError,
|
| 40 |
+
execute_hash_timed,
|
| 41 |
+
execute_once_timed,
|
| 42 |
+
execute_once_with_columns,
|
| 43 |
+
)
|
| 44 |
+
from engine.reward import (
|
| 45 |
+
SPEEDUP_CAP_FOR_INFTY,
|
| 46 |
+
STEP_REBATE_DESCRIBE_TABLE,
|
| 47 |
+
STEP_REBATE_EXPLAIN_QUERY,
|
| 48 |
+
STEP_REBATE_LIST_TABLES,
|
| 49 |
+
STEP_REBATE_READ_CHANGELOG,
|
| 50 |
+
STEP_REBATE_RUN_QUERY,
|
| 51 |
+
STEP_REBATE_SAMPLE_ROWS,
|
| 52 |
+
SqlDriftRubric,
|
| 53 |
+
canonicalize_sql,
|
| 54 |
+
effective_speedup,
|
| 55 |
+
)
|
| 56 |
+
from engine.runtime import RuntimeEpisodeState
|
| 57 |
+
from engine.verifier import canonical_row_hash
|
| 58 |
+
from models import (
|
| 59 |
+
REWARD_COMPONENT_KEYS,
|
| 60 |
+
ConsultDBAPayload,
|
| 61 |
+
ConsultDBAResult,
|
| 62 |
+
DescribeTablePayload,
|
| 63 |
+
DescribeTableResult,
|
| 64 |
+
EpisodePhase,
|
| 65 |
+
ExplainQueryPayload,
|
| 66 |
+
ExplainQueryResult,
|
| 67 |
+
ListTablesPayload,
|
| 68 |
+
ListTablesResult,
|
| 69 |
+
ReadChangelogPayload,
|
| 70 |
+
ReadChangelogResult,
|
| 71 |
+
RunQueryPayload,
|
| 72 |
+
RunQueryResult,
|
| 73 |
+
SampleRowsPayload,
|
| 74 |
+
SampleRowsResult,
|
| 75 |
+
SqlDriftAction,
|
| 76 |
+
SqlDriftObservation,
|
| 77 |
+
SqlDriftState,
|
| 78 |
+
SubmitRewritePayload,
|
| 79 |
+
SubmitRewriteResult,
|
| 80 |
+
ToolError,
|
| 81 |
+
ToolErrorCode,
|
| 82 |
+
ToolResult,
|
| 83 |
+
)
|
| 84 |
+
from scenarios import REGISTRY, get_spec
|
| 85 |
+
from skill_library import PlaybookEntry, Store, load_all, retrieve
|
| 86 |
+
from utilities.logger import get_module_logger, log_env_reset, log_env_step, log_interaction
|
| 87 |
+
|
| 88 |
+
from . import settings
|
| 89 |
+
|
| 90 |
+
if TYPE_CHECKING:
|
| 91 |
+
from scenarios.base import ScenarioSpec
|
| 92 |
+
|
| 93 |
+
_LOG = get_module_logger(__name__)
|
| 94 |
+
|
| 95 |
+
DEFAULT_STEP_BUDGET: int = settings.DEFAULT_STEP_BUDGET
|
| 96 |
+
MAX_RESULT_ROWS: int = settings.MAX_RESULT_ROWS
|
| 97 |
+
QUERY_TIMEOUT_S: float = settings.QUERY_TIMEOUT_S
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class _ResetOptions(BaseModel):
|
| 101 |
+
model_config = ConfigDict(extra="ignore")
|
| 102 |
+
|
| 103 |
+
scenario_id: str | None = None
|
| 104 |
+
enable_dba_oracle: bool | None = None
|
| 105 |
+
difficulty: Literal["easy", "normal", "hard"] = "normal"
|
| 106 |
+
budget_steps: int = Field(default=DEFAULT_STEP_BUDGET, ge=1)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
_READ_ONLY_EXPRESSION_KEYS: frozenset[str] = frozenset({"select", "with"})
|
| 110 |
+
|
| 111 |
+
# DuckDB exposes a family of table-valued functions and scalar helpers
|
| 112 |
+
# that read from the host filesystem or leak introspection state —
|
| 113 |
+
# ``read_csv``, ``read_parquet``, ``read_json``, ``read_text``,
|
| 114 |
+
# ``parquet_metadata``, ``duckdb_secrets``, ``glob``, etc. They are
|
| 115 |
+
# *technically* SELECT-shaped calls so the statement-key check alone
|
| 116 |
+
# admits them. We reject any function whose lowercased name starts with
|
| 117 |
+
# one of these prefixes or exactly matches one of the known-dangerous
|
| 118 |
+
# standalone names. Agent-facing SQL has no legitimate need for any of
|
| 119 |
+
# them — the DuckDB connection is pre-populated by the scenario builder.
|
| 120 |
+
_DENYLIST_PREFIXES: tuple[str, ...] = (
|
| 121 |
+
"read_",
|
| 122 |
+
"write_",
|
| 123 |
+
"copy_",
|
| 124 |
+
"duckdb_",
|
| 125 |
+
"pragma_",
|
| 126 |
+
"sniff_",
|
| 127 |
+
"parquet_",
|
| 128 |
+
"arrow_",
|
| 129 |
+
"json_table",
|
| 130 |
+
"json_each",
|
| 131 |
+
"sqlite_",
|
| 132 |
+
"load_",
|
| 133 |
+
"install_",
|
| 134 |
+
)
|
| 135 |
+
_DENYLIST_EXACT: frozenset[str] = frozenset(
|
| 136 |
+
{
|
| 137 |
+
"glob",
|
| 138 |
+
"attach",
|
| 139 |
+
"detach",
|
| 140 |
+
"checkpoint",
|
| 141 |
+
"force_checkpoint",
|
| 142 |
+
"set_secret",
|
| 143 |
+
"create_secret",
|
| 144 |
+
"drop_secret",
|
| 145 |
+
"enable_profiling",
|
| 146 |
+
"disable_profiling",
|
| 147 |
+
"enable_object_cache",
|
| 148 |
+
}
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _is_denylisted_function_name(name: str) -> bool:
|
| 153 |
+
"""Return True iff ``name`` (case-insensitively) matches a sandbox-escape."""
|
| 154 |
+
lowered = name.lower()
|
| 155 |
+
if lowered in _DENYLIST_EXACT:
|
| 156 |
+
return True
|
| 157 |
+
return any(lowered.startswith(p) for p in _DENYLIST_PREFIXES)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _function_names(node: sqlglot.exp.Func) -> list[str]:
|
| 161 |
+
"""All plausible names to check against the denylist for one AST node.
|
| 162 |
+
|
| 163 |
+
sqlglot lowers a few DuckDB calls into dedicated expression classes
|
| 164 |
+
(``ReadCSV``, ``ReadParquet``, …) whose ``.name`` is actually the
|
| 165 |
+
first positional arg — the file path — not the function name. We
|
| 166 |
+
recover the function name from the class name in that case and fall
|
| 167 |
+
back to ``.name`` for the ``Anonymous`` form that covers everything
|
| 168 |
+
else. Including both lets one denylist lookup cover both lowerings.
|
| 169 |
+
"""
|
| 170 |
+
cls = type(node).__name__
|
| 171 |
+
out: list[str] = []
|
| 172 |
+
# Derive a snake-case function name from the class name. We insert
|
| 173 |
+
# an underscore at two kinds of CamelCase boundaries:
|
| 174 |
+
#
|
| 175 |
+
# * ``aB`` — normal lower-to-upper (``ReadParquet`` → ``read_parquet``)
|
| 176 |
+
# * ``ABc`` — end of an acronym run (``ReadCSVAuto`` → ``read_csv_auto``)
|
| 177 |
+
#
|
| 178 |
+
# Purely-lowercase class names (``Anonymous``) produce no prefix
|
| 179 |
+
# match; we fall through to ``.name`` below for those.
|
| 180 |
+
if cls and cls[0].isupper():
|
| 181 |
+
snake = re.sub(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])", "_", cls).lower()
|
| 182 |
+
out.append(snake)
|
| 183 |
+
name_attr = getattr(node, "name", None)
|
| 184 |
+
if isinstance(name_attr, str) and name_attr:
|
| 185 |
+
out.append(name_attr)
|
| 186 |
+
return out
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
_VALID_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _resolve_timeout_s(timeout_s: float | None) -> float:
|
| 193 |
+
"""Caller-supplied per-step timeout or the module default.
|
| 194 |
+
|
| 195 |
+
``timeout_s`` is accepted on every OpenEnv ``step()`` (the abstract
|
| 196 |
+
base mandates the keyword). When the caller provides a positive
|
| 197 |
+
value we honour it as the wall-clock budget for any DuckDB query
|
| 198 |
+
this step runs; ``None`` and non-positive values fall back to the
|
| 199 |
+
module-level :data:`QUERY_TIMEOUT_S` so a mis-configured client
|
| 200 |
+
cannot silently disable the watchdog.
|
| 201 |
+
"""
|
| 202 |
+
if timeout_s is None or timeout_s <= 0:
|
| 203 |
+
return QUERY_TIMEOUT_S
|
| 204 |
+
return float(timeout_s)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def _initial_schema_synopsis(spec: ScenarioSpec, synopsis: str) -> str:
|
| 208 |
+
"""Reset-time synopsis with future drift details removed.
|
| 209 |
+
|
| 210 |
+
Drift scenarios should not reveal the exact schema/business-rule
|
| 211 |
+
change before the changelog is published at runtime. We therefore
|
| 212 |
+
trim the authored synopsis at the first ``" Under drift"`` clause on
|
| 213 |
+
reset and only surface the pre-drift schema shape.
|
| 214 |
+
"""
|
| 215 |
+
if spec.drift_config is None:
|
| 216 |
+
return synopsis
|
| 217 |
+
predrift, marker, _ = synopsis.partition(" Under drift")
|
| 218 |
+
return predrift if marker else synopsis
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def _validate_read_only_sql(sql: str) -> None:
|
| 222 |
+
"""Reject anything that isn't a single-statement read-only SELECT/CTE.
|
| 223 |
+
|
| 224 |
+
Raises ``ValueError`` so the caller can translate to a typed
|
| 225 |
+
:class:`models.ToolError` with :attr:`ToolErrorCode.INVALID_TOOL_ARGUMENT`.
|
| 226 |
+
This is the only place that mediates what the policy may execute;
|
| 227 |
+
scenario builders and drift DDL call DuckDB directly with privileged
|
| 228 |
+
SQL and deliberately bypass this check.
|
| 229 |
+
|
| 230 |
+
Beyond the statement-level gate, this walker also rejects two
|
| 231 |
+
sandbox-escape vectors that would otherwise ride along inside a
|
| 232 |
+
perfectly-shaped SELECT:
|
| 233 |
+
|
| 234 |
+
1. Table-valued functions that read from the host filesystem
|
| 235 |
+
(``read_csv``, ``read_parquet``, ``read_json_auto``, ``glob``,
|
| 236 |
+
``read_text``, …) or leak engine introspection (``duckdb_secrets``
|
| 237 |
+
carries credentials; ``duckdb_settings`` /``duckdb_functions``
|
| 238 |
+
can enumerate available exploits). See :data:`_DENYLIST_PREFIXES`
|
| 239 |
+
/ :data:`_DENYLIST_EXACT`.
|
| 240 |
+
2. ``SELECT * FROM 'path/to/x.csv'`` — DuckDB treats a bare string
|
| 241 |
+
literal in a FROM clause as a filesystem path and auto-detects
|
| 242 |
+
the format. There is no function node to inspect in this form,
|
| 243 |
+
so we separately reject any :class:`sqlglot.exp.Table` whose
|
| 244 |
+
backing expression is a string literal.
|
| 245 |
+
"""
|
| 246 |
+
try:
|
| 247 |
+
statements = sqlglot.parse(sql, dialect="duckdb")
|
| 248 |
+
except sqlglot.errors.ParseError as exc:
|
| 249 |
+
raise ValueError(f"SQL failed to parse: {exc}") from exc
|
| 250 |
+
|
| 251 |
+
non_empty = [s for s in statements if s is not None]
|
| 252 |
+
if len(non_empty) != 1:
|
| 253 |
+
raise ValueError("multi-statement SQL is not allowed; submit one SELECT")
|
| 254 |
+
expr = non_empty[0]
|
| 255 |
+
if expr.key not in _READ_ONLY_EXPRESSION_KEYS:
|
| 256 |
+
raise ValueError(
|
| 257 |
+
f"only read-only SELECT/CTE queries are allowed (got {expr.key.upper()} statement)"
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
for node in expr.walk():
|
| 261 |
+
# (1) Function-valued sandbox escapes. Inspect both the class
|
| 262 |
+
# name (catches ``ReadCSV`` / ``ReadParquet`` lowerings where
|
| 263 |
+
# ``.name`` holds the file path, not the function name) and
|
| 264 |
+
# ``.name`` (catches the generic ``Anonymous`` form).
|
| 265 |
+
if isinstance(node, sqlglot.exp.Func):
|
| 266 |
+
for fn_name in _function_names(node):
|
| 267 |
+
if _is_denylisted_function_name(fn_name):
|
| 268 |
+
raise ValueError(
|
| 269 |
+
f"function {fn_name!r} is not allowed — agent-facing SQL may "
|
| 270 |
+
"only touch the scenario's in-memory tables"
|
| 271 |
+
)
|
| 272 |
+
# (2) Bare-path FROM form: ``SELECT * FROM 'x.csv'`` or
|
| 273 |
+
# ``SELECT * FROM '/etc/passwd'``. sqlglot normalises both
|
| 274 |
+
# single- and double-quoted identifiers to
|
| 275 |
+
# ``Identifier(quoted=True)``, so we can't rely on the quote
|
| 276 |
+
# flavour to distinguish a file path from a legitimately-quoted
|
| 277 |
+
# table name. Instead we require every agent-facing table name
|
| 278 |
+
# to be a valid unquoted SQL identifier — the scenarios never
|
| 279 |
+
# emit anything else, and paths always contain ``/``, ``.`` or
|
| 280 |
+
# ``~`` which fail the identifier regex.
|
| 281 |
+
if isinstance(node, sqlglot.exp.Table):
|
| 282 |
+
inner = node.this
|
| 283 |
+
if isinstance(inner, sqlglot.exp.Identifier):
|
| 284 |
+
ident_name = inner.name
|
| 285 |
+
if ident_name and not _VALID_IDENTIFIER_RE.match(ident_name):
|
| 286 |
+
raise ValueError(
|
| 287 |
+
f"table identifier {ident_name!r} is not a valid unquoted SQL "
|
| 288 |
+
"name — reading from file paths or other engine-specific "
|
| 289 |
+
"resources is not allowed"
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class SqlDriftEnvironment(Environment[SqlDriftAction, SqlDriftObservation, SqlDriftState]):
|
| 294 |
+
"""OpenEnv environment for SQL repair + optimization under schema drift."""
|
| 295 |
+
|
| 296 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 297 |
+
|
| 298 |
+
def __init__(
|
| 299 |
+
self,
|
| 300 |
+
skill_store: Store | None = None,
|
| 301 |
+
cleanup_on_close: bool = False,
|
| 302 |
+
) -> None:
|
| 303 |
+
self._runtime: RuntimeEpisodeState | None = None
|
| 304 |
+
self._skill_store: Store | None = skill_store
|
| 305 |
+
# When True, the skill-store directory is deleted when close() is called.
|
| 306 |
+
# Set this for server-managed per-session stores so disk usage doesn't grow
|
| 307 |
+
# monotonically; see design/codereview.md (session store issue).
|
| 308 |
+
self._cleanup_on_close: bool = cleanup_on_close
|
| 309 |
+
super().__init__(
|
| 310 |
+
rubric=SqlDriftRubric(ctx_provider=lambda: self._require_runtime()),
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# ------------------------------------------------------------------
|
| 314 |
+
# OpenEnv contract
|
| 315 |
+
# ------------------------------------------------------------------
|
| 316 |
+
|
| 317 |
+
@log_env_reset
|
| 318 |
+
def reset(
|
| 319 |
+
self,
|
| 320 |
+
seed: int | None = None,
|
| 321 |
+
episode_id: str | None = None,
|
| 322 |
+
**kwargs: Any,
|
| 323 |
+
) -> SqlDriftObservation:
|
| 324 |
+
options = _ResetOptions.model_validate(kwargs)
|
| 325 |
+
scenario_id = options.scenario_id
|
| 326 |
+
enable_dba_oracle = dba_oracle.is_enabled(options.enable_dba_oracle)
|
| 327 |
+
difficulty = options.difficulty
|
| 328 |
+
budget_steps = options.budget_steps
|
| 329 |
+
|
| 330 |
+
if seed is None:
|
| 331 |
+
seed = secrets.randbits(31)
|
| 332 |
+
if episode_id is None:
|
| 333 |
+
episode_id = f"ep-{seed:08x}"
|
| 334 |
+
if scenario_id is None:
|
| 335 |
+
scenario_id = self._pick_scenario_for_seed(seed)
|
| 336 |
+
|
| 337 |
+
spec = get_spec(scenario_id)
|
| 338 |
+
instance = spec.materialize(seed, difficulty=difficulty)
|
| 339 |
+
|
| 340 |
+
drift_scheduled_step: int | None = None
|
| 341 |
+
if instance.drift_config is not None:
|
| 342 |
+
drift_scheduled_step = Random(seed).randint(
|
| 343 |
+
instance.drift_config.min_step,
|
| 344 |
+
instance.drift_config.max_step,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
self._close_existing_runtime()
|
| 348 |
+
self._runtime = RuntimeEpisodeState(
|
| 349 |
+
episode_id=episode_id,
|
| 350 |
+
seed=seed,
|
| 351 |
+
scenario_id=scenario_id,
|
| 352 |
+
instance=instance,
|
| 353 |
+
conn=instance.conn,
|
| 354 |
+
gt_result_hash_predrift=instance.gt_result_hash_predrift,
|
| 355 |
+
gt_result_hash_postdrift=None,
|
| 356 |
+
baseline_runtime_ms=instance.baseline_runtime_ms,
|
| 357 |
+
baseline_tokens=instance.baseline_tokens,
|
| 358 |
+
baseline_sql_canonical=canonicalize_sql(instance.baseline_sql),
|
| 359 |
+
baseline_postdrift_raises=False,
|
| 360 |
+
drift_scheduled_step=drift_scheduled_step,
|
| 361 |
+
budget_steps=budget_steps,
|
| 362 |
+
dba_oracle_enabled=enable_dba_oracle,
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
self._reset_rubric()
|
| 366 |
+
|
| 367 |
+
learned_hints = kwargs.get("learned_hints")
|
| 368 |
+
if learned_hints is None:
|
| 369 |
+
learned_hints = self._render_learned_hints(spec, include_drift_cards=False)
|
| 370 |
+
if len(learned_hints) > 800:
|
| 371 |
+
learned_hints = learned_hints[:800]
|
| 372 |
+
|
| 373 |
+
rt = self._require_runtime()
|
| 374 |
+
return SqlDriftObservation(
|
| 375 |
+
step=0,
|
| 376 |
+
phase=EpisodePhase.DIAGNOSE,
|
| 377 |
+
last_tool=None,
|
| 378 |
+
tool_result=None,
|
| 379 |
+
drift_fired=False,
|
| 380 |
+
drift_acknowledged=False,
|
| 381 |
+
learned_hints=learned_hints,
|
| 382 |
+
baseline_sql=instance.baseline_sql,
|
| 383 |
+
schema_synopsis=_initial_schema_synopsis(spec, instance.schema_synopsis),
|
| 384 |
+
budget_steps_remaining=rt.budget_steps_remaining,
|
| 385 |
+
reward_components={key: 0.0 for key in REWARD_COMPONENT_KEYS},
|
| 386 |
+
done=False,
|
| 387 |
+
reward=None,
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
@log_env_step
|
| 391 |
+
def step(
|
| 392 |
+
self,
|
| 393 |
+
action: SqlDriftAction,
|
| 394 |
+
timeout_s: float | None = None,
|
| 395 |
+
**kwargs: Any,
|
| 396 |
+
) -> SqlDriftObservation:
|
| 397 |
+
rt = self._require_runtime()
|
| 398 |
+
if rt.submitted or rt.budget_steps_remaining <= 0:
|
| 399 |
+
raise ValueError("Episode is already finished; call reset() to start a new episode.")
|
| 400 |
+
rt.step_count += 1
|
| 401 |
+
rt.last_step_was_tool_error = False
|
| 402 |
+
rt.last_step_was_repeat_failing_query = False
|
| 403 |
+
rt.last_step_repeat_failing_query_count = 0
|
| 404 |
+
rt.last_step_productive_rebate = 0.0
|
| 405 |
+
|
| 406 |
+
self._maybe_fire_drift()
|
| 407 |
+
|
| 408 |
+
effective_timeout_s = _resolve_timeout_s(timeout_s)
|
| 409 |
+
try:
|
| 410 |
+
tool_result = self._dispatch(action, timeout_s=effective_timeout_s)
|
| 411 |
+
except QueryWatchdogEscalationError:
|
| 412 |
+
rt.connection_poisoned = True
|
| 413 |
+
rt.phase = EpisodePhase.FINALIZE
|
| 414 |
+
rt.step_count = max(rt.step_count, rt.budget_steps)
|
| 415 |
+
_LOG.error("episode %s aborted after watchdog escalation", rt.episode_id)
|
| 416 |
+
raise
|
| 417 |
+
rt.last_step_was_tool_error = isinstance(tool_result, ToolError)
|
| 418 |
+
if rt.last_step_was_tool_error:
|
| 419 |
+
rt.consecutive_tool_errors += 1
|
| 420 |
+
else:
|
| 421 |
+
rt.consecutive_tool_errors = 0
|
| 422 |
+
|
| 423 |
+
done = rt.submitted or rt.budget_steps_remaining <= 0
|
| 424 |
+
|
| 425 |
+
obs = SqlDriftObservation(
|
| 426 |
+
step=rt.step_count,
|
| 427 |
+
phase=rt.phase,
|
| 428 |
+
last_tool=action.tool,
|
| 429 |
+
tool_result=tool_result,
|
| 430 |
+
drift_fired=rt.drift_fired,
|
| 431 |
+
drift_acknowledged=rt.drift_acknowledged,
|
| 432 |
+
learned_hints="",
|
| 433 |
+
baseline_sql="",
|
| 434 |
+
schema_synopsis="",
|
| 435 |
+
budget_steps_remaining=rt.budget_steps_remaining,
|
| 436 |
+
reward_components={key: 0.0 for key in REWARD_COMPONENT_KEYS},
|
| 437 |
+
done=done,
|
| 438 |
+
reward=None,
|
| 439 |
+
)
|
| 440 |
+
if rt.drift_acknowledged:
|
| 441 |
+
spec = get_spec(rt.scenario_id)
|
| 442 |
+
obs.learned_hints = self._render_learned_hints(spec, include_drift_cards=True)
|
| 443 |
+
|
| 444 |
+
obs.reward = self._apply_rubric(action, obs)
|
| 445 |
+
if self.rubric is not None:
|
| 446 |
+
obs.reward_components = self.rubric.component_scores()
|
| 447 |
+
|
| 448 |
+
if done and rt.submitted:
|
| 449 |
+
self._maybe_persist_learned_entry()
|
| 450 |
+
return obs
|
| 451 |
+
|
| 452 |
+
def render(self) -> dict[str, Any]:
|
| 453 |
+
"""Render the current public state and log the render interaction."""
|
| 454 |
+
rt = self._require_runtime()
|
| 455 |
+
state = self.state
|
| 456 |
+
payload = state.model_dump(mode="json")
|
| 457 |
+
log_interaction(
|
| 458 |
+
event_type="render",
|
| 459 |
+
agent_id=rt.episode_id,
|
| 460 |
+
observation_returned=payload,
|
| 461 |
+
done=rt.submitted or rt.budget_steps_remaining <= 0,
|
| 462 |
+
)
|
| 463 |
+
return payload
|
| 464 |
+
|
| 465 |
+
@property
|
| 466 |
+
def state(self) -> SqlDriftState:
|
| 467 |
+
"""Sanitised public state snapshot (explicit whitelist)."""
|
| 468 |
+
rt = self._require_runtime()
|
| 469 |
+
return SqlDriftState(
|
| 470 |
+
episode_id=rt.episode_id,
|
| 471 |
+
step_count=rt.step_count,
|
| 472 |
+
scenario_id=rt.scenario_id,
|
| 473 |
+
phase=rt.phase,
|
| 474 |
+
budget_steps_remaining=rt.budget_steps_remaining,
|
| 475 |
+
drift_fired=rt.drift_fired,
|
| 476 |
+
consultations_used=rt.consultations_used,
|
| 477 |
+
submitted=rt.submitted,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
def effective_speedup(self) -> float | None:
|
| 481 |
+
"""Return the current episode's effective speedup, if any."""
|
| 482 |
+
rt = self._runtime
|
| 483 |
+
if rt is None:
|
| 484 |
+
return None
|
| 485 |
+
return effective_speedup(rt)
|
| 486 |
+
|
| 487 |
+
def close(self) -> None:
|
| 488 |
+
self._close_existing_runtime()
|
| 489 |
+
if self._cleanup_on_close and self._skill_store is not None:
|
| 490 |
+
import shutil
|
| 491 |
+
|
| 492 |
+
store_dir = self._skill_store.dir
|
| 493 |
+
shutil.rmtree(store_dir, ignore_errors=True)
|
| 494 |
+
|
| 495 |
+
# ------------------------------------------------------------------
|
| 496 |
+
# Skill-library wiring
|
| 497 |
+
# ------------------------------------------------------------------
|
| 498 |
+
|
| 499 |
+
def _render_learned_hints(self, spec: ScenarioSpec, *, include_drift_cards: bool = True) -> str:
|
| 500 |
+
playbook, drift_cards = load_all(self._skill_store)
|
| 501 |
+
drift_kind = None
|
| 502 |
+
if include_drift_cards and spec.drift_config is not None:
|
| 503 |
+
drift_kind = spec.drift_config.kind
|
| 504 |
+
result = retrieve(
|
| 505 |
+
query_tags=spec.tags,
|
| 506 |
+
drift_kind=drift_kind,
|
| 507 |
+
playbook=playbook,
|
| 508 |
+
drift_cards=drift_cards,
|
| 509 |
+
)
|
| 510 |
+
return result.render(max_chars=800)
|
| 511 |
+
|
| 512 |
+
def _maybe_persist_learned_entry(self) -> None:
|
| 513 |
+
"""Append a PlaybookEntry on terminal success with a meaningful speedup.
|
| 514 |
+
|
| 515 |
+
Failures to persist are logged but never re-raised: a training
|
| 516 |
+
rollout should not crash because the on-disk playbook is under
|
| 517 |
+
contention. The skill store itself is crash-safe (atomic writes
|
| 518 |
+
+ file-lock) so at-most-once semantics are sufficient here.
|
| 519 |
+
"""
|
| 520 |
+
if self._skill_store is None:
|
| 521 |
+
return
|
| 522 |
+
rt = self._require_runtime()
|
| 523 |
+
if not rt.submitted:
|
| 524 |
+
return
|
| 525 |
+
if self.rubric is None:
|
| 526 |
+
return
|
| 527 |
+
scores = self.rubric.component_scores()
|
| 528 |
+
if scores.get("r_correct", 0.0) < 1.0:
|
| 529 |
+
return
|
| 530 |
+
spec = get_spec(rt.scenario_id)
|
| 531 |
+
raw_speedup = effective_speedup(rt)
|
| 532 |
+
# effective_speedup cannot return None here — rt.submitted is True
|
| 533 |
+
# so submitted_runtime_ms is populated — but we guard defensively.
|
| 534 |
+
# ``+∞`` (drift invalidated the baseline) is capped so the on-disk
|
| 535 |
+
# playbook doesn't serialize ``Infinity``, which would round-trip
|
| 536 |
+
# as a JSON parse error on load.
|
| 537 |
+
if raw_speedup is None or math.isinf(raw_speedup):
|
| 538 |
+
speedup_val = float(SPEEDUP_CAP_FOR_INFTY)
|
| 539 |
+
else:
|
| 540 |
+
speedup_val = float(raw_speedup)
|
| 541 |
+
entry = PlaybookEntry(
|
| 542 |
+
tag_set=spec.tags,
|
| 543 |
+
before_snippet=rt.instance.baseline_sql[:200],
|
| 544 |
+
after_snippet=(rt.submitted_sql or "")[:200],
|
| 545 |
+
avg_speedup=speedup_val,
|
| 546 |
+
scenario_family=spec.family,
|
| 547 |
+
source="learned",
|
| 548 |
+
)
|
| 549 |
+
try:
|
| 550 |
+
self._skill_store.append_playbook(entry)
|
| 551 |
+
except Exception as exc:
|
| 552 |
+
_LOG.warning("skill-library append_playbook failed: %s", exc)
|
| 553 |
+
|
| 554 |
+
# ------------------------------------------------------------------
|
| 555 |
+
# Internal helpers
|
| 556 |
+
# ------------------------------------------------------------------
|
| 557 |
+
|
| 558 |
+
def _grant_step_rebate_once(self, *, attr: str, rebate: float) -> None:
|
| 559 |
+
rt = self._require_runtime()
|
| 560 |
+
if getattr(rt, attr):
|
| 561 |
+
return
|
| 562 |
+
setattr(rt, attr, True)
|
| 563 |
+
rt.last_step_productive_rebate += rebate
|
| 564 |
+
|
| 565 |
+
def _grant_step_rebate_for_table(
|
| 566 |
+
self, *, rewarded_tables_attr: str, table: str, rebate: float
|
| 567 |
+
) -> None:
|
| 568 |
+
rt = self._require_runtime()
|
| 569 |
+
rewarded = getattr(rt, rewarded_tables_attr)
|
| 570 |
+
if table in rewarded:
|
| 571 |
+
return
|
| 572 |
+
rewarded.add(table)
|
| 573 |
+
rt.last_step_productive_rebate += rebate
|
| 574 |
+
|
| 575 |
+
@staticmethod
|
| 576 |
+
def _pick_scenario_for_seed(seed: int) -> str:
|
| 577 |
+
"""Deterministic round-robin over the sorted scenario registry."""
|
| 578 |
+
ids = sorted(REGISTRY)
|
| 579 |
+
if not ids:
|
| 580 |
+
raise RuntimeError("no scenarios registered")
|
| 581 |
+
return ids[seed % len(ids)]
|
| 582 |
+
|
| 583 |
+
def _require_runtime(self) -> RuntimeEpisodeState:
|
| 584 |
+
if self._runtime is None:
|
| 585 |
+
raise RuntimeError("SqlDriftEnvironment.reset() must be called before step()/state.")
|
| 586 |
+
return self._runtime
|
| 587 |
+
|
| 588 |
+
def _close_existing_runtime(self) -> None:
|
| 589 |
+
if self._runtime is not None:
|
| 590 |
+
if self._runtime.connection_poisoned:
|
| 591 |
+
_LOG.error(
|
| 592 |
+
"skipping close for poisoned DuckDB connection in episode %s",
|
| 593 |
+
self._runtime.episode_id,
|
| 594 |
+
)
|
| 595 |
+
else:
|
| 596 |
+
with contextlib.suppress(duckdb.Error):
|
| 597 |
+
self._runtime.conn.close()
|
| 598 |
+
self._runtime = None
|
| 599 |
+
|
| 600 |
+
def _maybe_fire_drift(self) -> None:
|
| 601 |
+
"""Apply drift when the step index crosses the schedule/cooldown threshold."""
|
| 602 |
+
rt = self._require_runtime()
|
| 603 |
+
if rt.drift_fired:
|
| 604 |
+
return
|
| 605 |
+
if rt.drift_scheduled_step is None:
|
| 606 |
+
return
|
| 607 |
+
if rt.first_run_query_step is None:
|
| 608 |
+
return
|
| 609 |
+
cfg = rt.instance.drift_config
|
| 610 |
+
assert cfg is not None
|
| 611 |
+
minimum = max(rt.drift_scheduled_step, rt.first_run_query_step + cfg.cooldown_steps)
|
| 612 |
+
if rt.step_count < minimum:
|
| 613 |
+
return
|
| 614 |
+
self._fire_drift()
|
| 615 |
+
|
| 616 |
+
def _fire_drift(self) -> None:
|
| 617 |
+
"""Apply drift, author a changelog, and resolve the post-drift GT hash.
|
| 618 |
+
|
| 619 |
+
Failure to recompute the post-drift GT hash is an authoring bug
|
| 620 |
+
(the scenario's ``gt_sql_postdrift`` must execute against the
|
| 621 |
+
just-mutated DB) and we re-raise loudly so it cannot silently
|
| 622 |
+
make every post-drift submission score ``r_correct=0``.
|
| 623 |
+
"""
|
| 624 |
+
rt = self._require_runtime()
|
| 625 |
+
cfg = rt.instance.drift_config
|
| 626 |
+
assert cfg is not None
|
| 627 |
+
apply_drift(rt.conn, cfg.kind, cfg.payload)
|
| 628 |
+
rt.drift_fired_step = rt.step_count
|
| 629 |
+
rt.phase = EpisodePhase.DRIFT_RECOVERY
|
| 630 |
+
rt.changelog_entries.append(author_changelog(cfg))
|
| 631 |
+
|
| 632 |
+
try:
|
| 633 |
+
rt.conn.execute(rt.instance.baseline_sql).fetchall()
|
| 634 |
+
rt.baseline_postdrift_raises = False
|
| 635 |
+
except duckdb.Error:
|
| 636 |
+
rt.baseline_postdrift_raises = True
|
| 637 |
+
|
| 638 |
+
if rt.instance.gt_sql_postdrift is not None:
|
| 639 |
+
try:
|
| 640 |
+
rows = rt.conn.execute(rt.instance.gt_sql_postdrift).fetchall()
|
| 641 |
+
except duckdb.Error as exc:
|
| 642 |
+
raise RuntimeError(
|
| 643 |
+
f"scenario {rt.scenario_id!r}: authored gt_sql_postdrift failed "
|
| 644 |
+
f"after drift: {exc}"
|
| 645 |
+
) from exc
|
| 646 |
+
rt.gt_result_hash_postdrift = canonical_row_hash(rows)
|
| 647 |
+
|
| 648 |
+
# ------------------------------------------------------------------
|
| 649 |
+
# Tool dispatch
|
| 650 |
+
# ------------------------------------------------------------------
|
| 651 |
+
|
| 652 |
+
def _dispatch(self, action: SqlDriftAction, *, timeout_s: float) -> ToolResult:
|
| 653 |
+
payload = action.payload
|
| 654 |
+
try:
|
| 655 |
+
if isinstance(payload, ListTablesPayload):
|
| 656 |
+
return self._handle_list_tables()
|
| 657 |
+
if isinstance(payload, DescribeTablePayload):
|
| 658 |
+
return self._handle_describe_table(payload)
|
| 659 |
+
if isinstance(payload, SampleRowsPayload):
|
| 660 |
+
return self._handle_sample_rows(payload)
|
| 661 |
+
if isinstance(payload, RunQueryPayload):
|
| 662 |
+
return self._handle_run_query(payload, timeout_s=timeout_s)
|
| 663 |
+
if isinstance(payload, ExplainQueryPayload):
|
| 664 |
+
return self._handle_explain_query(payload, timeout_s=timeout_s)
|
| 665 |
+
if isinstance(payload, ReadChangelogPayload):
|
| 666 |
+
return self._handle_read_changelog()
|
| 667 |
+
if isinstance(payload, SubmitRewritePayload):
|
| 668 |
+
return self._handle_submit_rewrite(payload, timeout_s=timeout_s)
|
| 669 |
+
if isinstance(payload, ConsultDBAPayload):
|
| 670 |
+
return self._handle_consult_dba(payload)
|
| 671 |
+
except duckdb.Error as exc:
|
| 672 |
+
return ToolError(code=ToolErrorCode.DB_ERROR, message=str(exc)[:2000])
|
| 673 |
+
except TimeoutError as exc:
|
| 674 |
+
return ToolError(code=ToolErrorCode.QUERY_TIMEOUT, message=str(exc)[:2000])
|
| 675 |
+
# Unreachable — the discriminated-union validator rejects unknown payloads.
|
| 676 |
+
return ToolError(
|
| 677 |
+
code=ToolErrorCode.INVALID_TOOL_ARGUMENT,
|
| 678 |
+
message=f"unknown payload type: {type(payload).__name__}",
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
def _handle_list_tables(self) -> ListTablesResult:
|
| 682 |
+
rt = self._require_runtime()
|
| 683 |
+
rows = rt.conn.execute(
|
| 684 |
+
"SELECT table_name FROM information_schema.tables "
|
| 685 |
+
"WHERE table_schema = 'main' ORDER BY table_name"
|
| 686 |
+
).fetchall()
|
| 687 |
+
self._grant_step_rebate_once(attr="listed_tables_rewarded", rebate=STEP_REBATE_LIST_TABLES)
|
| 688 |
+
self._mark_diagnostic()
|
| 689 |
+
return ListTablesResult(tables=[r[0] for r in rows])
|
| 690 |
+
|
| 691 |
+
def _handle_describe_table(
|
| 692 |
+
self, payload: DescribeTablePayload
|
| 693 |
+
) -> DescribeTableResult | ToolError:
|
| 694 |
+
rt = self._require_runtime()
|
| 695 |
+
rows = rt.conn.execute(
|
| 696 |
+
"SELECT column_name, data_type FROM information_schema.columns "
|
| 697 |
+
"WHERE table_name = ? ORDER BY ordinal_position",
|
| 698 |
+
[payload.table],
|
| 699 |
+
).fetchall()
|
| 700 |
+
if not rows:
|
| 701 |
+
return ToolError(
|
| 702 |
+
code=ToolErrorCode.UNKNOWN_TABLE,
|
| 703 |
+
message=f"unknown table: {payload.table}",
|
| 704 |
+
)
|
| 705 |
+
self._grant_step_rebate_for_table(
|
| 706 |
+
rewarded_tables_attr="described_tables_rewarded",
|
| 707 |
+
table=payload.table,
|
| 708 |
+
rebate=STEP_REBATE_DESCRIBE_TABLE,
|
| 709 |
+
)
|
| 710 |
+
self._mark_diagnostic()
|
| 711 |
+
return DescribeTableResult(
|
| 712 |
+
table=payload.table,
|
| 713 |
+
columns=[{"name": r[0], "type": r[1]} for r in rows],
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
def _handle_sample_rows(self, payload: SampleRowsPayload) -> SampleRowsResult | ToolError:
|
| 717 |
+
rt = self._require_runtime()
|
| 718 |
+
exists = rt.conn.execute(
|
| 719 |
+
"SELECT COUNT(*) FROM information_schema.tables WHERE table_name = ?",
|
| 720 |
+
[payload.table],
|
| 721 |
+
).fetchone()
|
| 722 |
+
if not exists or exists[0] == 0:
|
| 723 |
+
return ToolError(
|
| 724 |
+
code=ToolErrorCode.UNKNOWN_TABLE,
|
| 725 |
+
message=f"unknown table: {payload.table}",
|
| 726 |
+
)
|
| 727 |
+
cur = rt.conn.execute(f'SELECT * FROM "{payload.table}" LIMIT {payload.limit}')
|
| 728 |
+
columns = [d[0] for d in cur.description] if cur.description else []
|
| 729 |
+
rows = [list(r) for r in cur.fetchall()]
|
| 730 |
+
self._grant_step_rebate_for_table(
|
| 731 |
+
rewarded_tables_attr="sampled_tables_rewarded",
|
| 732 |
+
table=payload.table,
|
| 733 |
+
rebate=STEP_REBATE_SAMPLE_ROWS,
|
| 734 |
+
)
|
| 735 |
+
self._mark_diagnostic()
|
| 736 |
+
return SampleRowsResult(table=payload.table, columns=columns, rows=rows)
|
| 737 |
+
|
| 738 |
+
def _handle_run_query(
|
| 739 |
+
self, payload: RunQueryPayload, *, timeout_s: float
|
| 740 |
+
) -> RunQueryResult | ToolError:
|
| 741 |
+
rt = self._require_runtime()
|
| 742 |
+
sql = payload.sql
|
| 743 |
+
try:
|
| 744 |
+
_validate_read_only_sql(sql)
|
| 745 |
+
except ValueError as exc:
|
| 746 |
+
return ToolError(code=ToolErrorCode.INVALID_TOOL_ARGUMENT, message=str(exc)[:2000])
|
| 747 |
+
|
| 748 |
+
# Drift timing: after a valid
|
| 749 |
+
# ``run_query`` attempt, the pre-drift probe invariant is
|
| 750 |
+
# satisfied regardless of whether the execution ultimately
|
| 751 |
+
# returned rows, raised, or was capped for size. Assigning
|
| 752 |
+
# *before* execution means truncation, DB errors, and timeouts
|
| 753 |
+
# can no longer suppress drift firing in later steps.
|
| 754 |
+
if rt.first_run_query_step is None:
|
| 755 |
+
rt.first_run_query_step = rt.step_count
|
| 756 |
+
|
| 757 |
+
try:
|
| 758 |
+
result = execute_once_with_columns(
|
| 759 |
+
rt.conn, sql, timeout_s=timeout_s, max_rows=MAX_RESULT_ROWS
|
| 760 |
+
)
|
| 761 |
+
except TimeoutError as exc:
|
| 762 |
+
return ToolError(code=ToolErrorCode.QUERY_TIMEOUT, message=str(exc)[:2000])
|
| 763 |
+
except duckdb.Error as exc:
|
| 764 |
+
# Canonicalize *before* hashing so whitespace-/case-only
|
| 765 |
+
# variants of the same broken query count as the same repeat
|
| 766 |
+
# offence. canonicalize_sql falls back to a whitespace fold
|
| 767 |
+
# for SQL that sqlglot can't parse — still normalises the
|
| 768 |
+
# vast majority of "retried the same typo" cases.
|
| 769 |
+
failure_hash = canonical_row_hash([(canonicalize_sql(sql),)])
|
| 770 |
+
count = rt.failed_query_counts.get(failure_hash, 0) + 1
|
| 771 |
+
rt.failed_query_counts[failure_hash] = count
|
| 772 |
+
rt.failed_query_hashes.add(failure_hash)
|
| 773 |
+
rt.last_step_repeat_failing_query_count = count
|
| 774 |
+
rt.last_step_was_repeat_failing_query = count > 1
|
| 775 |
+
return ToolError(code=ToolErrorCode.DB_ERROR, message=str(exc)[:2000])
|
| 776 |
+
|
| 777 |
+
if result.truncated:
|
| 778 |
+
return ToolError(
|
| 779 |
+
code=ToolErrorCode.RESULT_TOO_LARGE,
|
| 780 |
+
message=(
|
| 781 |
+
f"result exceeded {MAX_RESULT_ROWS}-row cap — narrow the "
|
| 782 |
+
"projection, add a LIMIT, or aggregate"
|
| 783 |
+
),
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
self._grant_step_rebate_once(attr="run_query_rewarded", rebate=STEP_REBATE_RUN_QUERY)
|
| 787 |
+
self._mark_diagnostic()
|
| 788 |
+
return RunQueryResult(
|
| 789 |
+
columns=result.columns,
|
| 790 |
+
rows=[list(r) for r in result.rows],
|
| 791 |
+
runtime_ms=result.elapsed_ms,
|
| 792 |
+
row_count=len(result.rows),
|
| 793 |
+
)
|
| 794 |
+
|
| 795 |
+
def _handle_explain_query(
|
| 796 |
+
self, payload: ExplainQueryPayload, *, timeout_s: float
|
| 797 |
+
) -> ExplainQueryResult | ToolError:
|
| 798 |
+
rt = self._require_runtime()
|
| 799 |
+
try:
|
| 800 |
+
_validate_read_only_sql(payload.sql)
|
| 801 |
+
except ValueError as exc:
|
| 802 |
+
return ToolError(code=ToolErrorCode.INVALID_TOOL_ARGUMENT, message=str(exc)[:2000])
|
| 803 |
+
# EXPLAIN is plan-only (no data materialisation) but we still
|
| 804 |
+
# route it through the watchdog so a pathological query cannot
|
| 805 |
+
# burn the step budget past the caller's wall-clock deadline.
|
| 806 |
+
explain_rows, _ = execute_once_timed(rt.conn, f"EXPLAIN {payload.sql}", timeout_s=timeout_s)
|
| 807 |
+
plan = "\n".join(str(r[-1]) if r else "" for r in explain_rows)
|
| 808 |
+
self._grant_step_rebate_once(
|
| 809 |
+
attr="explain_query_rewarded", rebate=STEP_REBATE_EXPLAIN_QUERY
|
| 810 |
+
)
|
| 811 |
+
self._mark_diagnostic()
|
| 812 |
+
return ExplainQueryResult(plan=plan[:10_000])
|
| 813 |
+
|
| 814 |
+
def _handle_read_changelog(self) -> ReadChangelogResult:
|
| 815 |
+
rt = self._require_runtime()
|
| 816 |
+
if rt.changelog_entries:
|
| 817 |
+
rt.drift_acknowledged = True
|
| 818 |
+
self._grant_step_rebate_once(
|
| 819 |
+
attr="changelog_rewarded_after_drift",
|
| 820 |
+
rebate=STEP_REBATE_READ_CHANGELOG,
|
| 821 |
+
)
|
| 822 |
+
self._mark_diagnostic()
|
| 823 |
+
return ReadChangelogResult(entries=list(rt.changelog_entries))
|
| 824 |
+
|
| 825 |
+
def _handle_submit_rewrite(
|
| 826 |
+
self, payload: SubmitRewritePayload, *, timeout_s: float
|
| 827 |
+
) -> SubmitRewriteResult | ToolError:
|
| 828 |
+
rt = self._require_runtime()
|
| 829 |
+
if not rt.diagnostic_actions_taken:
|
| 830 |
+
return ToolError(
|
| 831 |
+
code=ToolErrorCode.SUBMIT_BEFORE_DIAGNOSE,
|
| 832 |
+
message=(
|
| 833 |
+
"submit_rewrite rejected: the agent must take at least one "
|
| 834 |
+
"diagnostic action (list_tables, describe_table, sample_rows, "
|
| 835 |
+
"run_query, explain_query, or read_changelog) before submitting."
|
| 836 |
+
),
|
| 837 |
+
)
|
| 838 |
+
sql = payload.sql
|
| 839 |
+
try:
|
| 840 |
+
_validate_read_only_sql(sql)
|
| 841 |
+
except ValueError as exc:
|
| 842 |
+
return ToolError(code=ToolErrorCode.INVALID_TOOL_ARGUMENT, message=str(exc)[:2000])
|
| 843 |
+
try:
|
| 844 |
+
agent_hash, elapsed_ms = execute_hash_timed(rt.conn, sql, timeout_s=timeout_s)
|
| 845 |
+
except TimeoutError as exc:
|
| 846 |
+
return ToolError(code=ToolErrorCode.QUERY_TIMEOUT, message=str(exc)[:2000])
|
| 847 |
+
except duckdb.Error as exc:
|
| 848 |
+
return ToolError(code=ToolErrorCode.DB_ERROR, message=str(exc)[:2000])
|
| 849 |
+
gt_hash = (
|
| 850 |
+
rt.gt_result_hash_postdrift
|
| 851 |
+
if rt.drift_fired and rt.gt_result_hash_postdrift is not None
|
| 852 |
+
else rt.gt_result_hash_predrift
|
| 853 |
+
)
|
| 854 |
+
matches = agent_hash == gt_hash
|
| 855 |
+
|
| 856 |
+
rt.submitted = True
|
| 857 |
+
rt.submitted_sql = sql
|
| 858 |
+
rt.submitted_sql_canonical = canonicalize_sql(sql)
|
| 859 |
+
rt.submitted_result_hash = agent_hash
|
| 860 |
+
rt.submitted_runtime_ms = elapsed_ms
|
| 861 |
+
rt.phase = EpisodePhase.FINALIZE
|
| 862 |
+
return SubmitRewriteResult(
|
| 863 |
+
accepted=True,
|
| 864 |
+
runtime_ms=elapsed_ms,
|
| 865 |
+
matches_ground_truth=matches,
|
| 866 |
+
)
|
| 867 |
+
|
| 868 |
+
def _handle_consult_dba(self, payload: ConsultDBAPayload) -> ConsultDBAResult | ToolError:
|
| 869 |
+
rt = self._require_runtime()
|
| 870 |
+
if not rt.dba_oracle_enabled:
|
| 871 |
+
return ToolError(
|
| 872 |
+
code=ToolErrorCode.INVALID_TOOL_ARGUMENT,
|
| 873 |
+
message="consult_dba disabled — set enable_dba_oracle=True at reset()",
|
| 874 |
+
)
|
| 875 |
+
if not dba_oracle.has_hints(rt.scenario_id):
|
| 876 |
+
return ToolError(
|
| 877 |
+
code=ToolErrorCode.INVALID_TOOL_ARGUMENT,
|
| 878 |
+
message=f"no DBA hints registered for scenario={rt.scenario_id!r}",
|
| 879 |
+
)
|
| 880 |
+
rt.consultations_used += 1
|
| 881 |
+
tier = min(rt.consultations_used, 3)
|
| 882 |
+
hint = dba_oracle.get_hint(rt.scenario_id, tier)
|
| 883 |
+
del payload # question is free-text context only; hints are scenario-keyed.
|
| 884 |
+
return ConsultDBAResult(tier=tier, hint=hint)
|
| 885 |
+
|
| 886 |
+
def _mark_diagnostic(self) -> None:
|
| 887 |
+
"""Record a successful diagnostic tool call and advance the phase machine."""
|
| 888 |
+
rt = self._require_runtime()
|
| 889 |
+
rt.diagnostic_actions_taken += 1
|
| 890 |
+
if rt.phase == EpisodePhase.DIAGNOSE:
|
| 891 |
+
rt.phase = EpisodePhase.REWRITE
|
| 892 |
+
|
| 893 |
+
|
| 894 |
+
__all__ = [
|
| 895 |
+
"DEFAULT_STEP_BUDGET",
|
| 896 |
+
"MAX_RESULT_ROWS",
|
| 897 |
+
"QUERY_TIMEOUT_S",
|
| 898 |
+
"SqlDriftEnvironment",
|
| 899 |
+
]
|
skill_library/__init__.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Self-improvement library — pre-seeded playbook + on-disk learned store."""
|
| 2 |
+
|
| 3 |
+
from skill_library.entries import (
|
| 4 |
+
DriftAdaptationCard,
|
| 5 |
+
PlaybookEntry,
|
| 6 |
+
RetrievalResult,
|
| 7 |
+
)
|
| 8 |
+
from skill_library.retrieval import (
|
| 9 |
+
JACCARD_MIN,
|
| 10 |
+
jaccard,
|
| 11 |
+
retrieve,
|
| 12 |
+
top_k_drift_cards,
|
| 13 |
+
top_k_playbook,
|
| 14 |
+
)
|
| 15 |
+
from skill_library.seeds import PRESEED_DRIFT_CARDS, PRESEED_PLAYBOOK
|
| 16 |
+
from skill_library.store import DEFAULT_STORE_DIR, Store, cleanup_stale_session_dirs
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def load_all(
|
| 20 |
+
store: Store | None = None,
|
| 21 |
+
) -> tuple[
|
| 22 |
+
tuple[PlaybookEntry, ...],
|
| 23 |
+
tuple[DriftAdaptationCard, ...],
|
| 24 |
+
]:
|
| 25 |
+
"""Union of pre-seeds and any entries persisted on disk.
|
| 26 |
+
|
| 27 |
+
Returns ``(playbook, drift_cards)``. Order: pre-seeds first, then
|
| 28 |
+
learned entries, so deterministic retrieval tie-breaks prefer
|
| 29 |
+
the hand-authored pre-seeds when tags and speedup match exactly.
|
| 30 |
+
"""
|
| 31 |
+
learned_pb: tuple[PlaybookEntry, ...] = ()
|
| 32 |
+
learned_dc: tuple[DriftAdaptationCard, ...] = ()
|
| 33 |
+
if store is not None:
|
| 34 |
+
learned_pb = store.read_playbook()
|
| 35 |
+
learned_dc = store.read_drift_cards()
|
| 36 |
+
return PRESEED_PLAYBOOK + learned_pb, PRESEED_DRIFT_CARDS + learned_dc
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
__all__ = [
|
| 40 |
+
"DEFAULT_STORE_DIR",
|
| 41 |
+
"DriftAdaptationCard",
|
| 42 |
+
"JACCARD_MIN",
|
| 43 |
+
"PRESEED_DRIFT_CARDS",
|
| 44 |
+
"PRESEED_PLAYBOOK",
|
| 45 |
+
"PlaybookEntry",
|
| 46 |
+
"RetrievalResult",
|
| 47 |
+
"Store",
|
| 48 |
+
"cleanup_stale_session_dirs",
|
| 49 |
+
"jaccard",
|
| 50 |
+
"load_all",
|
| 51 |
+
"retrieve",
|
| 52 |
+
"top_k_drift_cards",
|
| 53 |
+
"top_k_playbook",
|
| 54 |
+
]
|
skill_library/entries.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data classes for the self-improvement (skill) library.
|
| 2 |
+
|
| 3 |
+
Two entry kinds — both immutable dataclasses so they can live on
|
| 4 |
+
frozen-dict caches and be safely shared across episodes.
|
| 5 |
+
|
| 6 |
+
- :class:`PlaybookEntry` — a "before/after" SQL rewrite nugget,
|
| 7 |
+
tagged by anti-pattern + scenario family, with an empirical
|
| 8 |
+
speedup number. Populated by pre-seeds and extended at
|
| 9 |
+
terminal-success (``r_correct > 0 ∧ speedup > 1.2``).
|
| 10 |
+
- :class:`DriftAdaptationCard` — a drift-kind recovery card with a
|
| 11 |
+
symptom regex and a recovery template. Pre-seeded 1-per-drift-kind.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from typing import Literal
|
| 18 |
+
|
| 19 |
+
DriftKind = Literal["column_rename", "date_format", "enum_rule", "field_deprecation"]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass(frozen=True)
|
| 23 |
+
class PlaybookEntry:
|
| 24 |
+
"""A reusable SQL-rewrite recipe keyed by anti-pattern tags."""
|
| 25 |
+
|
| 26 |
+
tag_set: frozenset[str]
|
| 27 |
+
before_snippet: str
|
| 28 |
+
after_snippet: str
|
| 29 |
+
avg_speedup: float
|
| 30 |
+
scenario_family: str # matches scenarios.base.Family
|
| 31 |
+
# Optional provenance — "preseed" for hand-authored, "learned" for
|
| 32 |
+
# append-on-success entries. Used by the retrieval UI only.
|
| 33 |
+
source: Literal["preseed", "learned"] = "preseed"
|
| 34 |
+
|
| 35 |
+
def render_hint(self, max_chars: int = 200) -> str:
|
| 36 |
+
"""Render a one-liner suitable for inclusion in ``learned_hints``."""
|
| 37 |
+
body = (
|
| 38 |
+
f"[{self.scenario_family}] "
|
| 39 |
+
f"replace `{self.before_snippet[:60]}...` with "
|
| 40 |
+
f"`{self.after_snippet[:60]}...` "
|
| 41 |
+
f"(~{self.avg_speedup:.1f}x)"
|
| 42 |
+
)
|
| 43 |
+
return body[:max_chars]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass(frozen=True)
|
| 47 |
+
class DriftAdaptationCard:
|
| 48 |
+
"""A drift-kind recovery card."""
|
| 49 |
+
|
| 50 |
+
drift_kind: DriftKind
|
| 51 |
+
symptom_regex: str
|
| 52 |
+
recovery_template: str
|
| 53 |
+
success_rate: float = 0.0
|
| 54 |
+
source: Literal["preseed", "learned"] = "preseed"
|
| 55 |
+
|
| 56 |
+
def render_hint(self, max_chars: int = 200) -> str:
|
| 57 |
+
body = (
|
| 58 |
+
f"[drift:{self.drift_kind}] "
|
| 59 |
+
f"symptom=/{self.symptom_regex}/ → "
|
| 60 |
+
f"{self.recovery_template[:120]}"
|
| 61 |
+
)
|
| 62 |
+
return body[:max_chars]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@dataclass(frozen=True)
|
| 66 |
+
class RetrievalResult:
|
| 67 |
+
"""Top-k blend of playbook hits + drift cards for one retrieval call."""
|
| 68 |
+
|
| 69 |
+
playbook: tuple[PlaybookEntry, ...] = field(default_factory=tuple)
|
| 70 |
+
drift_cards: tuple[DriftAdaptationCard, ...] = field(default_factory=tuple)
|
| 71 |
+
|
| 72 |
+
def render(self, max_chars: int = 800) -> str:
|
| 73 |
+
"""Concatenate rendered hints, truncated to ``max_chars``.
|
| 74 |
+
|
| 75 |
+
Deterministic ordering: playbook entries first (by descending
|
| 76 |
+
``avg_speedup``, ties broken by ``before_snippet``), then drift
|
| 77 |
+
cards (by descending ``success_rate``, ties by ``drift_kind``).
|
| 78 |
+
"""
|
| 79 |
+
lines: list[str] = []
|
| 80 |
+
for e in self.playbook:
|
| 81 |
+
lines.append("- " + e.render_hint(max_chars=200))
|
| 82 |
+
for c in self.drift_cards:
|
| 83 |
+
lines.append("- " + c.render_hint(max_chars=200))
|
| 84 |
+
out = "\n".join(lines)
|
| 85 |
+
return out[:max_chars]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
__all__ = [
|
| 89 |
+
"DriftAdaptationCard",
|
| 90 |
+
"DriftKind",
|
| 91 |
+
"PlaybookEntry",
|
| 92 |
+
"RetrievalResult",
|
| 93 |
+
]
|
skill_library/retrieval.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pure-function Jaccard top-k retrieval over tag sets.
|
| 2 |
+
|
| 3 |
+
No embeddings, no torch, no tokenizer. Deterministic — same inputs
|
| 4 |
+
produce the same ranking, same top-k, same tie-break.
|
| 5 |
+
|
| 6 |
+
A conservative Jaccard threshold (0.3) limits retrieval noise when
|
| 7 |
+
broad pre-seeds would otherwise match every scenario.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from collections.abc import Iterable
|
| 13 |
+
|
| 14 |
+
from skill_library.entries import (
|
| 15 |
+
DriftAdaptationCard,
|
| 16 |
+
PlaybookEntry,
|
| 17 |
+
RetrievalResult,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
JACCARD_MIN: float = 0.3
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def jaccard(a: frozenset[str], b: frozenset[str]) -> float:
|
| 24 |
+
"""Standard Jaccard on sets."""
|
| 25 |
+
if not a and not b:
|
| 26 |
+
return 1.0
|
| 27 |
+
if not a or not b:
|
| 28 |
+
return 0.0
|
| 29 |
+
return len(a & b) / len(a | b)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def top_k_playbook(
|
| 33 |
+
query_tags: frozenset[str],
|
| 34 |
+
entries: Iterable[PlaybookEntry],
|
| 35 |
+
k: int = 3,
|
| 36 |
+
*,
|
| 37 |
+
min_overlap: float = JACCARD_MIN,
|
| 38 |
+
) -> tuple[PlaybookEntry, ...]:
|
| 39 |
+
"""Top-k playbook entries by Jaccard overlap with ``query_tags``.
|
| 40 |
+
|
| 41 |
+
Ties broken by descending ``avg_speedup`` then by ``before_snippet``
|
| 42 |
+
lexicographic order so the result is stable across runs.
|
| 43 |
+
"""
|
| 44 |
+
scored = [(jaccard(query_tags, e.tag_set), e) for e in entries]
|
| 45 |
+
scored = [(j, e) for j, e in scored if j >= min_overlap]
|
| 46 |
+
scored.sort(key=lambda t: (-t[0], -t[1].avg_speedup, t[1].before_snippet))
|
| 47 |
+
return tuple(e for _, e in scored[:k])
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def top_k_drift_cards(
|
| 51 |
+
drift_kind: str | None,
|
| 52 |
+
cards: Iterable[DriftAdaptationCard],
|
| 53 |
+
k: int = 1,
|
| 54 |
+
) -> tuple[DriftAdaptationCard, ...]:
|
| 55 |
+
"""Filter cards by exact drift_kind match, sorted by success_rate desc."""
|
| 56 |
+
if drift_kind is None:
|
| 57 |
+
return ()
|
| 58 |
+
matches = [c for c in cards if c.drift_kind == drift_kind]
|
| 59 |
+
matches.sort(key=lambda c: (-c.success_rate, c.drift_kind))
|
| 60 |
+
return tuple(matches[:k])
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def retrieve(
|
| 64 |
+
query_tags: frozenset[str],
|
| 65 |
+
drift_kind: str | None,
|
| 66 |
+
playbook: Iterable[PlaybookEntry],
|
| 67 |
+
drift_cards: Iterable[DriftAdaptationCard],
|
| 68 |
+
*,
|
| 69 |
+
playbook_k: int = 3,
|
| 70 |
+
drift_k: int = 1,
|
| 71 |
+
) -> RetrievalResult:
|
| 72 |
+
"""Combined retrieval: top-k playbook + top-k drift cards."""
|
| 73 |
+
return RetrievalResult(
|
| 74 |
+
playbook=top_k_playbook(query_tags, playbook, k=playbook_k),
|
| 75 |
+
drift_cards=top_k_drift_cards(drift_kind, drift_cards, k=drift_k),
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
__all__ = [
|
| 80 |
+
"JACCARD_MIN",
|
| 81 |
+
"jaccard",
|
| 82 |
+
"retrieve",
|
| 83 |
+
"top_k_drift_cards",
|
| 84 |
+
"top_k_playbook",
|
| 85 |
+
]
|
skill_library/seeds.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hand-authored pre-seed library — 8 playbook entries + 4 drift cards.
|
| 2 |
+
|
| 3 |
+
Tag sets here match the tags on each corresponding scenario
|
| 4 |
+
so a Jaccard top-k against the scenario's tag set returns
|
| 5 |
+
a relevant entry for every one of the 10 scenarios — with two
|
| 6 |
+
generic fallbacks for coverage on novel anti-patterns.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from skill_library.entries import DriftAdaptationCard, PlaybookEntry
|
| 12 |
+
|
| 13 |
+
PRESEED_PLAYBOOK: tuple[PlaybookEntry, ...] = (
|
| 14 |
+
# 1. Correlated subquery → LEFT JOIN + GROUP BY
|
| 15 |
+
PlaybookEntry(
|
| 16 |
+
tag_set=frozenset({"correlated_subquery", "projection_subquery"}),
|
| 17 |
+
before_snippet=(
|
| 18 |
+
"SELECT u.*, (SELECT COUNT(*) FROM orders o WHERE o.user_id=u.id) FROM users u"
|
| 19 |
+
),
|
| 20 |
+
after_snippet=(
|
| 21 |
+
"SELECT u.*, COALESCE(c.n, 0) FROM users u "
|
| 22 |
+
"LEFT JOIN (SELECT user_id, COUNT(*) n FROM orders GROUP BY user_id) c "
|
| 23 |
+
"ON c.user_id = u.id"
|
| 24 |
+
),
|
| 25 |
+
avg_speedup=6.0,
|
| 26 |
+
scenario_family="ecommerce",
|
| 27 |
+
),
|
| 28 |
+
# 2. SELECT * + join → project only needed columns
|
| 29 |
+
PlaybookEntry(
|
| 30 |
+
tag_set=frozenset({"select_star", "over_projection", "join"}),
|
| 31 |
+
before_snippet="SELECT * FROM a JOIN b ON a.id=b.a_id",
|
| 32 |
+
after_snippet="SELECT a.id, a.name, b.amount FROM a JOIN b ON a.id=b.a_id",
|
| 33 |
+
avg_speedup=2.5,
|
| 34 |
+
scenario_family="ecommerce",
|
| 35 |
+
),
|
| 36 |
+
# 3. Cartesian join — add explicit ON clause
|
| 37 |
+
PlaybookEntry(
|
| 38 |
+
tag_set=frozenset({"cartesian", "missing_join_condition"}),
|
| 39 |
+
before_snippet="SELECT * FROM a, b WHERE a.region = 'US'",
|
| 40 |
+
after_snippet="SELECT a.col FROM a JOIN b ON a.id = b.a_id WHERE a.region = 'US'",
|
| 41 |
+
avg_speedup=50.0,
|
| 42 |
+
scenario_family="events",
|
| 43 |
+
),
|
| 44 |
+
# 4. DISTINCT on GROUP BY — drop one
|
| 45 |
+
PlaybookEntry(
|
| 46 |
+
tag_set=frozenset({"distinct", "redundant_distinct", "group_by"}),
|
| 47 |
+
before_snippet="SELECT DISTINCT tenant_id, count(*) FROM logs GROUP BY tenant_id",
|
| 48 |
+
after_snippet="SELECT tenant_id, count(*) FROM logs GROUP BY tenant_id",
|
| 49 |
+
avg_speedup=1.4,
|
| 50 |
+
scenario_family="saas_logs",
|
| 51 |
+
),
|
| 52 |
+
# 5. Nested IN-subquery → JOIN
|
| 53 |
+
PlaybookEntry(
|
| 54 |
+
tag_set=frozenset({"nested_subquery", "in_subquery"}),
|
| 55 |
+
before_snippet="WHERE id IN (SELECT x_id FROM x WHERE ... )",
|
| 56 |
+
after_snippet="JOIN x ON x.x_id = table.id WHERE ...",
|
| 57 |
+
avg_speedup=3.0,
|
| 58 |
+
scenario_family="cms",
|
| 59 |
+
),
|
| 60 |
+
# 6. HAVING filter on groupable column → push to WHERE
|
| 61 |
+
PlaybookEntry(
|
| 62 |
+
tag_set=frozenset({"having_as_where", "aggregate_filter"}),
|
| 63 |
+
before_snippet="GROUP BY x, status HAVING status = 'fulfilled'",
|
| 64 |
+
after_snippet="WHERE status = 'fulfilled' GROUP BY x",
|
| 65 |
+
avg_speedup=2.0,
|
| 66 |
+
scenario_family="ecommerce",
|
| 67 |
+
),
|
| 68 |
+
# 7. Generic: prefer JOINs over correlated subqueries
|
| 69 |
+
PlaybookEntry(
|
| 70 |
+
tag_set=frozenset({"subquery", "generic"}),
|
| 71 |
+
before_snippet="scalar subquery in SELECT list",
|
| 72 |
+
after_snippet="LEFT JOIN with aggregated CTE",
|
| 73 |
+
avg_speedup=4.0,
|
| 74 |
+
scenario_family="ecommerce",
|
| 75 |
+
),
|
| 76 |
+
# 8. Generic: project only used columns
|
| 77 |
+
PlaybookEntry(
|
| 78 |
+
tag_set=frozenset({"over_projection", "generic"}),
|
| 79 |
+
before_snippet="SELECT *",
|
| 80 |
+
after_snippet="SELECT <only needed columns>",
|
| 81 |
+
avg_speedup=1.8,
|
| 82 |
+
scenario_family="ecommerce",
|
| 83 |
+
),
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
PRESEED_DRIFT_CARDS: tuple[DriftAdaptationCard, ...] = (
|
| 88 |
+
DriftAdaptationCard(
|
| 89 |
+
drift_kind="column_rename",
|
| 90 |
+
symptom_regex=r'column ".+" does not exist',
|
| 91 |
+
recovery_template=(
|
| 92 |
+
"Read the changelog, update every identifier referencing the old column, and resubmit."
|
| 93 |
+
),
|
| 94 |
+
success_rate=0.9,
|
| 95 |
+
),
|
| 96 |
+
DriftAdaptationCard(
|
| 97 |
+
drift_kind="date_format",
|
| 98 |
+
symptom_regex=r"Could not convert string .+ to TIMESTAMP|BIGINT",
|
| 99 |
+
recovery_template=(
|
| 100 |
+
"Epoch-ms columns are BIGINT; cast your filter bounds with "
|
| 101 |
+
"`EXTRACT(EPOCH FROM TIMESTAMP '...') * 1000` or use numeric literals."
|
| 102 |
+
),
|
| 103 |
+
success_rate=0.85,
|
| 104 |
+
),
|
| 105 |
+
DriftAdaptationCard(
|
| 106 |
+
drift_kind="enum_rule",
|
| 107 |
+
symptom_regex=r"(empty|zero) result set on filter `... = 'active'`",
|
| 108 |
+
recovery_template=(
|
| 109 |
+
"A single enum value may have been split into several; use `IN "
|
| 110 |
+
"('ACTIVE', 'ACTIVE_V2')` instead of equality."
|
| 111 |
+
),
|
| 112 |
+
success_rate=0.8,
|
| 113 |
+
),
|
| 114 |
+
DriftAdaptationCard(
|
| 115 |
+
drift_kind="field_deprecation",
|
| 116 |
+
symptom_regex=r'column ".+" does not exist|non-existent column',
|
| 117 |
+
recovery_template=(
|
| 118 |
+
"The inline field was replaced by a FK; JOIN the lookup table and "
|
| 119 |
+
"project the human-readable name from there."
|
| 120 |
+
),
|
| 121 |
+
success_rate=0.75,
|
| 122 |
+
),
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
__all__ = ["PRESEED_DRIFT_CARDS", "PRESEED_PLAYBOOK"]
|
skill_library/store.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""JSON-backed playbook and drift-card store with file locking.
|
| 2 |
+
|
| 3 |
+
Each store file is a JSON array. Appends go through a single locked
|
| 4 |
+
read-modify-write cycle that writes to a sibling ``*.tmp`` file and
|
| 5 |
+
atomically ``os.replace``s it onto the target path, so a crash can only
|
| 6 |
+
leave either the old array or the new one — never a truncated file.
|
| 7 |
+
|
| 8 |
+
The lock is held on a dedicated ``*.lock`` file via ``fcntl.flock`` with
|
| 9 |
+
a caller-configurable timeout (default 5s). We never lock the data file
|
| 10 |
+
itself: that way an ``os.replace`` inside the critical section can't
|
| 11 |
+
race against a reader holding a shared lock on the old inode.
|
| 12 |
+
|
| 13 |
+
Reads are cached by mtime so hot-path episodes don't re-parse the file
|
| 14 |
+
on every ``reset()``. Corrupt trailers (from a pre-atomic-write era or
|
| 15 |
+
a partial disk write) log a warning and fall back to empty — we prefer
|
| 16 |
+
a running trainer over one that dies because of a bad card.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import contextlib
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
import time
|
| 25 |
+
from collections.abc import Callable, Iterator
|
| 26 |
+
from dataclasses import asdict
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import Any, Literal, TypeVar, cast
|
| 29 |
+
|
| 30 |
+
from skill_library.entries import DriftAdaptationCard, PlaybookEntry
|
| 31 |
+
from utilities.logger import get_module_logger
|
| 32 |
+
|
| 33 |
+
_LOG = get_module_logger(__name__)
|
| 34 |
+
|
| 35 |
+
DEFAULT_STORE_DIR = Path("outputs") / "skill_library"
|
| 36 |
+
PLAYBOOK_FILENAME = "playbook.json"
|
| 37 |
+
DRIFT_CARDS_FILENAME = "drift_cards.json"
|
| 38 |
+
DEFAULT_LOCK_TIMEOUT_S: float = 5.0
|
| 39 |
+
|
| 40 |
+
T = TypeVar("T")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
import fcntl
|
| 45 |
+
|
| 46 |
+
def _try_lock_exclusive(fh: Any) -> bool:
|
| 47 |
+
try:
|
| 48 |
+
fcntl.flock(fh.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
|
| 49 |
+
return True
|
| 50 |
+
except BlockingIOError:
|
| 51 |
+
return False
|
| 52 |
+
|
| 53 |
+
def _unlock(fh: Any) -> None:
|
| 54 |
+
fcntl.flock(fh.fileno(), fcntl.LOCK_UN)
|
| 55 |
+
|
| 56 |
+
_HAS_FCNTL = True
|
| 57 |
+
|
| 58 |
+
except ImportError:
|
| 59 |
+
_HAS_FCNTL = False
|
| 60 |
+
|
| 61 |
+
def _try_lock_exclusive(fh: Any) -> bool:
|
| 62 |
+
return True
|
| 63 |
+
|
| 64 |
+
def _unlock(fh: Any) -> None:
|
| 65 |
+
return None
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@contextlib.contextmanager
|
| 69 |
+
def _locked(path: Path, timeout_s: float) -> Iterator[None]:
|
| 70 |
+
"""Poll-acquire an exclusive flock on ``path`` within ``timeout_s``."""
|
| 71 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 72 |
+
deadline = time.monotonic() + timeout_s
|
| 73 |
+
with path.open("a+") as fh:
|
| 74 |
+
while not _try_lock_exclusive(fh):
|
| 75 |
+
if time.monotonic() >= deadline:
|
| 76 |
+
raise TimeoutError(f"could not acquire {path} within {timeout_s}s")
|
| 77 |
+
time.sleep(0.02)
|
| 78 |
+
try:
|
| 79 |
+
yield
|
| 80 |
+
finally:
|
| 81 |
+
if _HAS_FCNTL:
|
| 82 |
+
_unlock(fh)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _atomic_write_json(path: Path, payload: list[dict[str, Any]]) -> None:
|
| 86 |
+
tmp = path.with_suffix(path.suffix + ".tmp")
|
| 87 |
+
text = json.dumps(payload, indent=2)
|
| 88 |
+
with tmp.open("w", encoding="utf-8") as fh:
|
| 89 |
+
fh.write(text)
|
| 90 |
+
fh.flush()
|
| 91 |
+
os.fsync(fh.fileno())
|
| 92 |
+
os.replace(tmp, path)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _read_json_array(path: Path) -> list[Any]:
|
| 96 |
+
if not path.exists():
|
| 97 |
+
return []
|
| 98 |
+
try:
|
| 99 |
+
raw = path.read_text(encoding="utf-8")
|
| 100 |
+
except OSError as exc:
|
| 101 |
+
_LOG.warning("skill-store read failed for %s: %s", path, exc)
|
| 102 |
+
return []
|
| 103 |
+
if not raw.strip():
|
| 104 |
+
return []
|
| 105 |
+
try:
|
| 106 |
+
data = json.loads(raw)
|
| 107 |
+
except json.JSONDecodeError as exc:
|
| 108 |
+
_LOG.warning("skill-store corrupt at %s (%s); returning empty", path, exc)
|
| 109 |
+
return []
|
| 110 |
+
return data if isinstance(data, list) else []
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class Store:
|
| 114 |
+
"""Append-only JSON store for learned playbook entries + drift cards."""
|
| 115 |
+
|
| 116 |
+
def __init__(
|
| 117 |
+
self,
|
| 118 |
+
directory: Path | None = None,
|
| 119 |
+
lock_timeout_s: float = DEFAULT_LOCK_TIMEOUT_S,
|
| 120 |
+
) -> None:
|
| 121 |
+
self.dir = Path(directory) if directory is not None else DEFAULT_STORE_DIR
|
| 122 |
+
self.lock_timeout_s = lock_timeout_s
|
| 123 |
+
self._playbook_cache: tuple[tuple[PlaybookEntry, ...], float] | None = None
|
| 124 |
+
self._drift_cache: tuple[tuple[DriftAdaptationCard, ...], float] | None = None
|
| 125 |
+
|
| 126 |
+
def playbook_path(self) -> Path:
|
| 127 |
+
return self.dir / PLAYBOOK_FILENAME
|
| 128 |
+
|
| 129 |
+
def drift_cards_path(self) -> Path:
|
| 130 |
+
return self.dir / DRIFT_CARDS_FILENAME
|
| 131 |
+
|
| 132 |
+
def read_playbook(self) -> tuple[PlaybookEntry, ...]:
|
| 133 |
+
return self._read_cached(
|
| 134 |
+
self.playbook_path(),
|
| 135 |
+
cache_attr="_playbook_cache",
|
| 136 |
+
decode=_entry_from_dict,
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
def read_drift_cards(self) -> tuple[DriftAdaptationCard, ...]:
|
| 140 |
+
return self._read_cached(
|
| 141 |
+
self.drift_cards_path(),
|
| 142 |
+
cache_attr="_drift_cache",
|
| 143 |
+
decode=lambda d: DriftAdaptationCard(**d),
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
def append_playbook(self, entry: PlaybookEntry) -> None:
|
| 147 |
+
self._locked_append(
|
| 148 |
+
self.playbook_path(),
|
| 149 |
+
encode_new=_entry_to_dict,
|
| 150 |
+
new_item=entry,
|
| 151 |
+
)
|
| 152 |
+
self._playbook_cache = None
|
| 153 |
+
|
| 154 |
+
def append_drift_card(self, card: DriftAdaptationCard) -> None:
|
| 155 |
+
self._locked_append(
|
| 156 |
+
self.drift_cards_path(),
|
| 157 |
+
encode_new=asdict,
|
| 158 |
+
new_item=card,
|
| 159 |
+
)
|
| 160 |
+
self._drift_cache = None
|
| 161 |
+
|
| 162 |
+
def _read_cached(
|
| 163 |
+
self,
|
| 164 |
+
path: Path,
|
| 165 |
+
*,
|
| 166 |
+
cache_attr: str,
|
| 167 |
+
decode: Callable[[dict[str, Any]], T],
|
| 168 |
+
) -> tuple[T, ...]:
|
| 169 |
+
mtime = _safe_mtime(path)
|
| 170 |
+
# ``getattr``/``setattr`` is intentional — the same implementation
|
| 171 |
+
# services both the playbook and drift-card caches, whose Python
|
| 172 |
+
# types differ. The cast below restores the precise
|
| 173 |
+
# ``(tuple[T, ...], float) | None`` shape for mypy.
|
| 174 |
+
cached = cast("tuple[tuple[T, ...], float] | None", getattr(self, cache_attr))
|
| 175 |
+
if cached is not None and cached[1] == mtime:
|
| 176 |
+
return cached[0]
|
| 177 |
+
items: list[T] = []
|
| 178 |
+
for d in _read_json_array(path):
|
| 179 |
+
try:
|
| 180 |
+
items.append(decode(d))
|
| 181 |
+
except (TypeError, KeyError, ValueError) as exc:
|
| 182 |
+
_LOG.warning("skipping malformed store entry %s: %s", d, exc)
|
| 183 |
+
tup = tuple(items)
|
| 184 |
+
setattr(self, cache_attr, (tup, mtime))
|
| 185 |
+
return tup
|
| 186 |
+
|
| 187 |
+
def _locked_append(
|
| 188 |
+
self,
|
| 189 |
+
path: Path,
|
| 190 |
+
*,
|
| 191 |
+
encode_new: Callable[[Any], dict[str, Any]],
|
| 192 |
+
new_item: Any,
|
| 193 |
+
) -> None:
|
| 194 |
+
self.dir.mkdir(parents=True, exist_ok=True)
|
| 195 |
+
lock_path = path.with_suffix(path.suffix + ".lock")
|
| 196 |
+
with _locked(lock_path, self.lock_timeout_s):
|
| 197 |
+
existing = _read_json_array(path)
|
| 198 |
+
existing.append(encode_new(new_item))
|
| 199 |
+
_atomic_write_json(path, existing)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def _safe_mtime(path: Path) -> float:
|
| 203 |
+
try:
|
| 204 |
+
return path.stat().st_mtime
|
| 205 |
+
except FileNotFoundError:
|
| 206 |
+
return 0.0
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def _entry_to_dict(e: PlaybookEntry) -> dict[str, Any]:
|
| 210 |
+
return {
|
| 211 |
+
"tag_set": sorted(e.tag_set),
|
| 212 |
+
"before_snippet": e.before_snippet,
|
| 213 |
+
"after_snippet": e.after_snippet,
|
| 214 |
+
"avg_speedup": e.avg_speedup,
|
| 215 |
+
"scenario_family": e.scenario_family,
|
| 216 |
+
"source": e.source,
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _entry_from_dict(d: dict[str, Any]) -> PlaybookEntry:
|
| 221 |
+
source: Literal["preseed", "learned"] = d.get("source", "learned")
|
| 222 |
+
return PlaybookEntry(
|
| 223 |
+
tag_set=frozenset(d.get("tag_set") or []),
|
| 224 |
+
before_snippet=d["before_snippet"],
|
| 225 |
+
after_snippet=d["after_snippet"],
|
| 226 |
+
avg_speedup=float(d["avg_speedup"]),
|
| 227 |
+
scenario_family=d["scenario_family"],
|
| 228 |
+
source=source,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def cleanup_stale_session_dirs(root: Path, ttl_hours: float) -> int:
|
| 233 |
+
"""Remove session subdirectories under *root* whose mtime is older than *ttl_hours*.
|
| 234 |
+
|
| 235 |
+
Returns the number of directories removed. Errors on individual
|
| 236 |
+
subdirectories are logged and skipped so a single bad entry cannot abort
|
| 237 |
+
the sweep. Pass ``ttl_hours=0`` to disable (returns 0 immediately).
|
| 238 |
+
"""
|
| 239 |
+
import shutil
|
| 240 |
+
|
| 241 |
+
if ttl_hours <= 0 or not root.exists():
|
| 242 |
+
return 0
|
| 243 |
+
cutoff = time.time() - ttl_hours * 3600
|
| 244 |
+
removed = 0
|
| 245 |
+
for session_dir in root.iterdir():
|
| 246 |
+
if not session_dir.is_dir():
|
| 247 |
+
continue
|
| 248 |
+
try:
|
| 249 |
+
if session_dir.stat().st_mtime < cutoff:
|
| 250 |
+
shutil.rmtree(session_dir, ignore_errors=True)
|
| 251 |
+
removed += 1
|
| 252 |
+
except OSError as exc:
|
| 253 |
+
_LOG.warning("cleanup_stale_session_dirs: skipping %s: %s", session_dir, exc)
|
| 254 |
+
return removed
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
__all__ = ["DEFAULT_LOCK_TIMEOUT_S", "DEFAULT_STORE_DIR", "Store", "cleanup_stale_session_dirs"]
|
sql_drift_grpo_training.ipynb
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# SQLDrift GRPO Training Notebook\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Runnable hackathon training artifact for SQLDrift. This notebook uses the repository's existing `training.grpo_train` entrypoint, which wires Unsloth, Hugging Face TRL `GRPOTrainer`, Transformers, PEFT LoRA, and the SQLDrift OpenEnv tool environment.\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"Assumptions:\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"- You run this on a GPU runtime, preferably Colab with an NVIDIA T4/A10/L4 or better.\n",
|
| 14 |
+
"- `SQL_DRIFT_ENV_URL` points at the deployed SQLDrift HF Space, for example `https://your-space.hf.space`.\n",
|
| 15 |
+
"- The notebook is run from the repository root, or `SQL_DRIFT_REPO_URL` is set so Colab can clone the repo.\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"Success criteria:\n",
|
| 18 |
+
"\n",
|
| 19 |
+
"1. The deployed environment health check succeeds.\n",
|
| 20 |
+
"2. GRPO trains for the configured number of steps against live SQLDrift rollouts.\n",
|
| 21 |
+
"3. A LoRA adapter/checkpoint is written under `outputs/checkpoints/grpo_hackathon/`.\n",
|
| 22 |
+
"4. Durable evidence files are written under `training/evidence/`: `grpo_metrics.csv`, `grpo_loss_curve.png`, and `grpo_reward_curve.png`.\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"After pushing this notebook to GitHub, open it in Colab by using:\n",
|
| 25 |
+
"`https://colab.research.google.com/github/<YOUR_GITHUB_USER>/<YOUR_REPO>/blob/main/training/sql_drift_grpo_training.ipynb`\n"
|
| 26 |
+
]
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"cell_type": "markdown",
|
| 30 |
+
"metadata": {},
|
| 31 |
+
"source": [
|
| 32 |
+
"## 1. Install The Training Stack\n",
|
| 33 |
+
"\n",
|
| 34 |
+
"If this notebook is launched directly in Colab, set `SQL_DRIFT_REPO_URL` first. If it is already running from the repository root, the clone step is skipped.\n"
|
| 35 |
+
]
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"cell_type": "code",
|
| 39 |
+
"execution_count": null,
|
| 40 |
+
"metadata": {},
|
| 41 |
+
"outputs": [],
|
| 42 |
+
"source": [
|
| 43 |
+
"from pathlib import Path\n",
|
| 44 |
+
"import os\n",
|
| 45 |
+
"import subprocess\n",
|
| 46 |
+
"import sys\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"REPO_URL = os.environ.get(\"SQL_DRIFT_REPO_URL\", \"\").strip()\n",
|
| 49 |
+
"repo_root = Path.cwd()\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"if not (repo_root / \"pyproject.toml\").exists():\n",
|
| 52 |
+
" if not REPO_URL:\n",
|
| 53 |
+
" raise RuntimeError(\n",
|
| 54 |
+
" \"Set SQL_DRIFT_REPO_URL to your GitHub repo URL, or run this notebook \"\n",
|
| 55 |
+
" \"from the sql_drift_env repository root.\"\n",
|
| 56 |
+
" )\n",
|
| 57 |
+
" subprocess.check_call([\"git\", \"clone\", REPO_URL, \"sql_drift_env\"])\n",
|
| 58 |
+
" os.chdir(\"sql_drift_env\")\n",
|
| 59 |
+
" repo_root = Path.cwd()\n",
|
| 60 |
+
"\n",
|
| 61 |
+
"subprocess.check_call(\n",
|
| 62 |
+
" [\n",
|
| 63 |
+
" sys.executable,\n",
|
| 64 |
+
" \"-m\",\n",
|
| 65 |
+
" \"pip\",\n",
|
| 66 |
+
" \"install\",\n",
|
| 67 |
+
" \"-q\",\n",
|
| 68 |
+
" \"-e\",\n",
|
| 69 |
+
" \".[train]\",\n",
|
| 70 |
+
" \"matplotlib\",\n",
|
| 71 |
+
" \"pandas\",\n",
|
| 72 |
+
" \"peft\",\n",
|
| 73 |
+
" \"bitsandbytes\",\n",
|
| 74 |
+
" ]\n",
|
| 75 |
+
")\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"print(f\"Repo root: {repo_root}\")"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "markdown",
|
| 82 |
+
"metadata": {},
|
| 83 |
+
"source": [
|
| 84 |
+
"## 2. Point At The Deployed SQLDrift Environment\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"Set `SQL_DRIFT_ENV_URL` to the HF Space URL before running the cell, for example:\n",
|
| 87 |
+
"\n",
|
| 88 |
+
"```python\n",
|
| 89 |
+
"import os\n",
|
| 90 |
+
"os.environ[\"SQL_DRIFT_ENV_URL\"] = \"https://your-space.hf.space\"\n",
|
| 91 |
+
"```\n"
|
| 92 |
+
]
|
| 93 |
+
},
|
| 94 |
+
{
|
| 95 |
+
"cell_type": "code",
|
| 96 |
+
"execution_count": null,
|
| 97 |
+
"metadata": {},
|
| 98 |
+
"outputs": [],
|
| 99 |
+
"source": [
|
| 100 |
+
"import json\n",
|
| 101 |
+
"import urllib.request\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"ENV_BASE_URL = os.environ.get(\"SQL_DRIFT_ENV_URL\", \"\").strip()\n",
|
| 104 |
+
"if not ENV_BASE_URL:\n",
|
| 105 |
+
" raise RuntimeError(\"Set SQL_DRIFT_ENV_URL to the deployed SQLDrift HF Space URL.\")\n",
|
| 106 |
+
"\n",
|
| 107 |
+
"health_url = ENV_BASE_URL.rstrip(\"/\") + \"/health\"\n",
|
| 108 |
+
"with urllib.request.urlopen(health_url, timeout=30) as response:\n",
|
| 109 |
+
" body = response.read().decode(\"utf-8\")\n",
|
| 110 |
+
"\n",
|
| 111 |
+
"print(f\"Health check OK: {health_url}\")\n",
|
| 112 |
+
"print(body[:500])"
|
| 113 |
+
]
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"cell_type": "markdown",
|
| 117 |
+
"metadata": {},
|
| 118 |
+
"source": [
|
| 119 |
+
"## 3. Configure A Minimal GRPO Run\n",
|
| 120 |
+
"\n",
|
| 121 |
+
"Defaults are intentionally small enough for a hackathon Colab run while still producing non-trivial curves. Increase `SQL_DRIFT_GRPO_MAX_STEPS` for final evidence.\n"
|
| 122 |
+
]
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"cell_type": "code",
|
| 126 |
+
"execution_count": null,
|
| 127 |
+
"metadata": {},
|
| 128 |
+
"outputs": [],
|
| 129 |
+
"source": [
|
| 130 |
+
"from dataclasses import asdict\n",
|
| 131 |
+
"from pathlib import Path\n",
|
| 132 |
+
"\n",
|
| 133 |
+
"from training.config import ALL_SCENARIOS, CurriculumConfig, GRPOConfig\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"MAX_STEPS = int(os.environ.get(\"SQL_DRIFT_GRPO_MAX_STEPS\", \"80\"))\n",
|
| 136 |
+
"GROUP_SIZE = int(os.environ.get(\"SQL_DRIFT_GRPO_GROUP_SIZE\", \"4\"))\n",
|
| 137 |
+
"SEED = int(os.environ.get(\"SQL_DRIFT_GRPO_SEED\", \"7\"))\n",
|
| 138 |
+
"OUTPUT_DIR = \"outputs/checkpoints/grpo_hackathon\"\n",
|
| 139 |
+
"EVIDENCE_DIR = Path(\"training/evidence\")\n",
|
| 140 |
+
"EVIDENCE_DIR.mkdir(parents=True, exist_ok=True)\n",
|
| 141 |
+
"\n",
|
| 142 |
+
"# P0 audit fixes (design/RL_HACKATHON_AUDIT_AND_RUNBOOK.md \u00a72.6):\n",
|
| 143 |
+
"# - max_seq_length 4096 \u2192 2048 (T4 16 GB OOM safety)\n",
|
| 144 |
+
"# - max_completion_length 1024 \u2192 256 (JSON tool envelopes need ~200 tok)\n",
|
| 145 |
+
"# - max_prompt_length kept large at 1536 (system prompt + tool history)\n",
|
| 146 |
+
"# - save_steps max(20, MAX_STEPS // 4) so a step-79 crash still leaves a checkpoint\n",
|
| 147 |
+
"# - default model swap to unsloth/Qwen3-1.7B-unsloth-bnb-4bit per audit \u00a74\n",
|
| 148 |
+
"cfg = GRPOConfig(\n",
|
| 149 |
+
" model_name=os.environ.get(\"SQL_DRIFT_MODEL_NAME\", \"unsloth/Qwen3-1.7B-unsloth-bnb-4bit\"),\n",
|
| 150 |
+
" env_base_url=ENV_BASE_URL,\n",
|
| 151 |
+
" output_dir=OUTPUT_DIR,\n",
|
| 152 |
+
" max_steps=MAX_STEPS,\n",
|
| 153 |
+
" group_size=GROUP_SIZE,\n",
|
| 154 |
+
" gradient_accumulation_steps=int(os.environ.get(\"SQL_DRIFT_GRAD_ACCUM\", \"1\")),\n",
|
| 155 |
+
" learning_rate=float(os.environ.get(\"SQL_DRIFT_LR\", \"5e-6\")),\n",
|
| 156 |
+
" warmup_steps=min(10, max(1, MAX_STEPS // 10)),\n",
|
| 157 |
+
" save_steps=max(20, MAX_STEPS // 4),\n",
|
| 158 |
+
" logging_steps=1,\n",
|
| 159 |
+
" seed=SEED,\n",
|
| 160 |
+
" max_seq_length=2048,\n",
|
| 161 |
+
" max_prompt_length=1536,\n",
|
| 162 |
+
" max_completion_length=256,\n",
|
| 163 |
+
" curriculum=CurriculumConfig(scenarios=ALL_SCENARIOS, mode=\"weighted\", weights=(1, 1, 1, 1, 1, 1, 2, 2, 2, 2)),\n",
|
| 164 |
+
")\n",
|
| 165 |
+
"\n",
|
| 166 |
+
"(EVIDENCE_DIR / \"grpo_config.json\").write_text(json.dumps(asdict(cfg), indent=2))\n",
|
| 167 |
+
"print(json.dumps(asdict(cfg), indent=2))"
|
| 168 |
+
]
|
| 169 |
+
},
|
| 170 |
+
{
|
| 171 |
+
"cell_type": "markdown",
|
| 172 |
+
"metadata": {},
|
| 173 |
+
"source": [
|
| 174 |
+
"## 4. Run GRPO Training\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"This calls `training.grpo_train.train(cfg)`, which loads the Unsloth model, attaches LoRA adapters, builds the curriculum dataset, and lets TRL drive multi-turn OpenEnv rollouts through `SqlDriftToolEnv`.\n"
|
| 177 |
+
]
|
| 178 |
+
},
|
| 179 |
+
{
|
| 180 |
+
"cell_type": "markdown",
|
| 181 |
+
"metadata": {},
|
| 182 |
+
"source": [
|
| 183 |
+
"### 4a. Free CPU smoke (max_steps=2)\n",
|
| 184 |
+
"\n",
|
| 185 |
+
"Pre-flight: validate env URL, dataset build, callback registration, and log file plumbing\n",
|
| 186 |
+
"**before** spending GPU time. P0 fix per `design/RL_HACKATHON_AUDIT_AND_RUNBOOK.md` \u00a75.3.\n",
|
| 187 |
+
"The assertion at the end aborts the notebook if the smoke produced no records."
|
| 188 |
+
]
|
| 189 |
+
},
|
| 190 |
+
{
|
| 191 |
+
"cell_type": "code",
|
| 192 |
+
"metadata": {},
|
| 193 |
+
"execution_count": null,
|
| 194 |
+
"outputs": [],
|
| 195 |
+
"source": [
|
| 196 |
+
"from dataclasses import replace\n",
|
| 197 |
+
"\n",
|
| 198 |
+
"smoke_cfg = replace(\n",
|
| 199 |
+
" cfg,\n",
|
| 200 |
+
" max_steps=2,\n",
|
| 201 |
+
" save_steps=2,\n",
|
| 202 |
+
" output_dir=str(EVIDENCE_DIR / \"smoke\"),\n",
|
| 203 |
+
")\n",
|
| 204 |
+
"print(f\"\\n=== SMOKE TEST: {smoke_cfg.max_steps} steps, output={smoke_cfg.output_dir} ===\\n\")\n",
|
| 205 |
+
"\n",
|
| 206 |
+
"from training.grpo_train import train as _train\n",
|
| 207 |
+
"\n",
|
| 208 |
+
"_smoke_trainer = _train(smoke_cfg)\n",
|
| 209 |
+
"_smoke_history = list(getattr(_smoke_trainer.state, \"log_history\", []))\n",
|
| 210 |
+
"assert _smoke_history, \"Smoke produced no log records \u2014 abort before the paid run.\"\n",
|
| 211 |
+
"print(f\"Smoke OK \u2014 {len(_smoke_history)} log records. Proceed to the paid run below.\")"
|
| 212 |
+
]
|
| 213 |
+
},
|
| 214 |
+
{
|
| 215 |
+
"cell_type": "code",
|
| 216 |
+
"execution_count": null,
|
| 217 |
+
"metadata": {},
|
| 218 |
+
"outputs": [],
|
| 219 |
+
"source": [
|
| 220 |
+
"from training.grpo_train import train\n",
|
| 221 |
+
"\n",
|
| 222 |
+
"trainer = train(cfg)\n",
|
| 223 |
+
"history = list(getattr(trainer.state, \"log_history\", []))\n",
|
| 224 |
+
"print(f\"Captured {len(history)} trainer log records\")\n",
|
| 225 |
+
"print(f\"Checkpoint directory: {OUTPUT_DIR}\")"
|
| 226 |
+
]
|
| 227 |
+
},
|
| 228 |
+
{
|
| 229 |
+
"cell_type": "markdown",
|
| 230 |
+
"metadata": {},
|
| 231 |
+
"source": [
|
| 232 |
+
"## 5. Save Evidence Curves\n",
|
| 233 |
+
"\n",
|
| 234 |
+
"The evaluator requires committed image files, not only WandB links or inline notebook plots. This cell writes PNGs and the raw metrics CSV into `training/evidence/`.\n"
|
| 235 |
+
]
|
| 236 |
+
},
|
| 237 |
+
{
|
| 238 |
+
"cell_type": "code",
|
| 239 |
+
"execution_count": null,
|
| 240 |
+
"metadata": {},
|
| 241 |
+
"outputs": [],
|
| 242 |
+
"source": [
|
| 243 |
+
"import pandas as pd\n",
|
| 244 |
+
"import matplotlib.pyplot as plt\n",
|
| 245 |
+
"\n",
|
| 246 |
+
"df = pd.DataFrame(history)\n",
|
| 247 |
+
"if df.empty:\n",
|
| 248 |
+
" raise RuntimeError(\"Trainer log history is empty; training did not emit metrics.\")\n",
|
| 249 |
+
"\n",
|
| 250 |
+
"if \"step\" not in df.columns:\n",
|
| 251 |
+
" df[\"step\"] = range(len(df))\n",
|
| 252 |
+
"\n",
|
| 253 |
+
"metrics_csv = EVIDENCE_DIR / \"grpo_metrics.csv\"\n",
|
| 254 |
+
"df.to_csv(metrics_csv, index=False)\n",
|
| 255 |
+
"\n",
|
| 256 |
+
"\n",
|
| 257 |
+
"def _numeric_columns(frame: pd.DataFrame) -> dict[str, pd.Series]:\n",
|
| 258 |
+
" out: dict[str, pd.Series] = {}\n",
|
| 259 |
+
" for column in frame.columns:\n",
|
| 260 |
+
" series = pd.to_numeric(frame[column], errors=\"coerce\")\n",
|
| 261 |
+
" if series.notna().any():\n",
|
| 262 |
+
" out[column] = series\n",
|
| 263 |
+
" return out\n",
|
| 264 |
+
"\n",
|
| 265 |
+
"\n",
|
| 266 |
+
"def _choose_metric(\n",
|
| 267 |
+
" frame: pd.DataFrame,\n",
|
| 268 |
+
" *,\n",
|
| 269 |
+
" preferred: list[str],\n",
|
| 270 |
+
" required_token: str,\n",
|
| 271 |
+
" exclude_tokens: tuple[str, ...] = (),\n",
|
| 272 |
+
") -> tuple[str, pd.Series]:\n",
|
| 273 |
+
" numeric = _numeric_columns(frame)\n",
|
| 274 |
+
" for column in preferred:\n",
|
| 275 |
+
" if column in numeric:\n",
|
| 276 |
+
" return column, numeric[column]\n",
|
| 277 |
+
" for column, series in numeric.items():\n",
|
| 278 |
+
" lowered = column.lower()\n",
|
| 279 |
+
" if required_token in lowered and not any(token in lowered for token in exclude_tokens):\n",
|
| 280 |
+
" return column, series\n",
|
| 281 |
+
" available = \", \".join(sorted(numeric))\n",
|
| 282 |
+
" raise RuntimeError(\n",
|
| 283 |
+
" f\"Could not find a numeric {required_token!r} metric. Available: {available}\"\n",
|
| 284 |
+
" )\n",
|
| 285 |
+
"\n",
|
| 286 |
+
"\n",
|
| 287 |
+
"def _plot_curve(column: str, values: pd.Series, ylabel: str, title: str, path: Path) -> None:\n",
|
| 288 |
+
" plot_df = pd.DataFrame(\n",
|
| 289 |
+
" {\n",
|
| 290 |
+
" \"step\": pd.to_numeric(df[\"step\"], errors=\"coerce\"),\n",
|
| 291 |
+
" \"value\": values,\n",
|
| 292 |
+
" }\n",
|
| 293 |
+
" ).dropna()\n",
|
| 294 |
+
" if plot_df.empty:\n",
|
| 295 |
+
" raise RuntimeError(f\"Metric {column!r} has no plottable values.\")\n",
|
| 296 |
+
"\n",
|
| 297 |
+
" plt.figure(figsize=(8, 4.5))\n",
|
| 298 |
+
" plt.plot(plot_df[\"step\"], plot_df[\"value\"], marker=\"o\", linewidth=1.4, label=column)\n",
|
| 299 |
+
" if len(plot_df) >= 5:\n",
|
| 300 |
+
" window = min(10, max(2, len(plot_df) // 5))\n",
|
| 301 |
+
" rolling = plot_df[\"value\"].rolling(window=window, min_periods=1).mean()\n",
|
| 302 |
+
" plt.plot(plot_df[\"step\"], rolling, linewidth=2.2, label=f\"{window}-point rolling mean\")\n",
|
| 303 |
+
" plt.xlabel(\"GRPO step\")\n",
|
| 304 |
+
" plt.ylabel(ylabel)\n",
|
| 305 |
+
" plt.title(title)\n",
|
| 306 |
+
" plt.grid(alpha=0.25)\n",
|
| 307 |
+
" plt.legend()\n",
|
| 308 |
+
" plt.tight_layout()\n",
|
| 309 |
+
" plt.savefig(path, dpi=180, bbox_inches=\"tight\")\n",
|
| 310 |
+
" plt.show()\n",
|
| 311 |
+
"\n",
|
| 312 |
+
"\n",
|
| 313 |
+
"loss_col, loss_values = _choose_metric(\n",
|
| 314 |
+
" df,\n",
|
| 315 |
+
" preferred=[\"loss\", \"train_loss\"],\n",
|
| 316 |
+
" required_token=\"loss\",\n",
|
| 317 |
+
")\n",
|
| 318 |
+
"reward_col, reward_values = _choose_metric(\n",
|
| 319 |
+
" df,\n",
|
| 320 |
+
" preferred=[\"reward\", \"rewards/mean\", \"mean_reward\", \"reward_mean\", \"train/reward\"],\n",
|
| 321 |
+
" required_token=\"reward\",\n",
|
| 322 |
+
" exclude_tokens=(\"std\", \"variance\"),\n",
|
| 323 |
+
")\n",
|
| 324 |
+
"\n",
|
| 325 |
+
"loss_png = EVIDENCE_DIR / \"grpo_loss_curve.png\"\n",
|
| 326 |
+
"reward_png = EVIDENCE_DIR / \"grpo_reward_curve.png\"\n",
|
| 327 |
+
"_plot_curve(loss_col, loss_values, \"loss\", \"SQLDrift GRPO Loss Curve\", loss_png)\n",
|
| 328 |
+
"_plot_curve(reward_col, reward_values, \"episode reward\", \"SQLDrift GRPO Reward Curve\", reward_png)\n",
|
| 329 |
+
"\n",
|
| 330 |
+
"print(f\"Wrote metrics: {metrics_csv}\")\n",
|
| 331 |
+
"print(f\"Wrote loss curve: {loss_png}\")\n",
|
| 332 |
+
"print(f\"Wrote reward curve: {reward_png}\")"
|
| 333 |
+
]
|
| 334 |
+
},
|
| 335 |
+
{
|
| 336 |
+
"cell_type": "markdown",
|
| 337 |
+
"metadata": {},
|
| 338 |
+
"source": [
|
| 339 |
+
"## 6. Commit The Evidence\n",
|
| 340 |
+
"\n",
|
| 341 |
+
"Once training completes, commit the notebook plus the generated evidence images:\n",
|
| 342 |
+
"\n",
|
| 343 |
+
"```bash\n",
|
| 344 |
+
"git add training/sql_drift_grpo_training.ipynb training/evidence/grpo_loss_curve.png training/evidence/grpo_reward_curve.png training/evidence/grpo_metrics.csv training/evidence/grpo_config.json\n",
|
| 345 |
+
"git commit -m \"Add SQLDrift GRPO training notebook and evidence\"\n",
|
| 346 |
+
"```\n",
|
| 347 |
+
"\n",
|
| 348 |
+
"Do not rely on WandB-only links or plots that live only in notebook output cells; the PNG files above are the durable artifacts evaluators can validate.\n"
|
| 349 |
+
]
|
| 350 |
+
}
|
| 351 |
+
],
|
| 352 |
+
"metadata": {
|
| 353 |
+
"kernelspec": {
|
| 354 |
+
"display_name": "Python 3",
|
| 355 |
+
"language": "python",
|
| 356 |
+
"name": "python3"
|
| 357 |
+
},
|
| 358 |
+
"language_info": {
|
| 359 |
+
"codemirror_mode": {
|
| 360 |
+
"name": "ipython",
|
| 361 |
+
"version": 3
|
| 362 |
+
},
|
| 363 |
+
"file_extension": ".py",
|
| 364 |
+
"mimetype": "text/x-python",
|
| 365 |
+
"name": "python",
|
| 366 |
+
"nbconvert_exporter": "python",
|
| 367 |
+
"pygments_lexer": "ipython3",
|
| 368 |
+
"version": "3.12"
|
| 369 |
+
}
|
| 370 |
+
},
|
| 371 |
+
"nbformat": 4,
|
| 372 |
+
"nbformat_minor": 5
|
| 373 |
+
}
|
training/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TRL/Unsloth GRPO training harness — P12/P13.
|
| 2 |
+
|
| 3 |
+
Public surface:
|
| 4 |
+
|
| 5 |
+
* :class:`training.config.GRPOConfig` / :class:`training.config.CurriculumConfig`
|
| 6 |
+
* :func:`training.prompt.render_system_prompt`
|
| 7 |
+
* :class:`training.random_agent.RandomAgent`
|
| 8 |
+
* :func:`training.grpo_train.train` (requires GPU + ``[train]`` extra)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from training.config import ALL_SCENARIOS, CurriculumConfig, GRPOConfig
|
| 14 |
+
from training.prompt import (
|
| 15 |
+
render_prompt_from_observation,
|
| 16 |
+
render_system_prompt,
|
| 17 |
+
)
|
| 18 |
+
from training.random_agent import RandomAgent
|
| 19 |
+
|
| 20 |
+
__all__ = [
|
| 21 |
+
"ALL_SCENARIOS",
|
| 22 |
+
"CurriculumConfig",
|
| 23 |
+
"GRPOConfig",
|
| 24 |
+
"RandomAgent",
|
| 25 |
+
"render_prompt_from_observation",
|
| 26 |
+
"render_system_prompt",
|
| 27 |
+
]
|
training/config.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training configuration dataclasses.
|
| 2 |
+
|
| 3 |
+
Holds every knob the :mod:`training.grpo_train` script or the eval CLI
|
| 4 |
+
needs, as plain, frozen dataclasses so they serialize cleanly to JSON
|
| 5 |
+
for experiment manifests.
|
| 6 |
+
|
| 7 |
+
Deliberately lightweight: do not import ``trl`` / ``unsloth`` /
|
| 8 |
+
``transformers`` at module import time. Those libraries are CUDA-heavy
|
| 9 |
+
and optional. ``grpo_train.py`` resolves them lazily.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from dataclasses import dataclass, field
|
| 15 |
+
from typing import Literal
|
| 16 |
+
|
| 17 |
+
from utilities.env_loader import env_str
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _load_all_scenarios() -> tuple[str, ...]:
|
| 21 |
+
from scenarios import iter_specs
|
| 22 |
+
|
| 23 |
+
return tuple(spec.scenario_id for spec in iter_specs())
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Derived from the live registry so training defaults stay in sync with
|
| 27 |
+
# the scenarios actually shipped under ``scenarios/``.
|
| 28 |
+
ALL_SCENARIOS: tuple[str, ...] = _load_all_scenarios()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass(frozen=True)
|
| 32 |
+
class CurriculumConfig:
|
| 33 |
+
"""Scenario sampling policy for GRPO rollouts.
|
| 34 |
+
|
| 35 |
+
``mode="uniform"`` samples each id in :attr:`scenarios` with equal
|
| 36 |
+
probability. ``mode="weighted"`` uses :attr:`weights` (must be the
|
| 37 |
+
same length as :attr:`scenarios`) — useful for over-sampling drift
|
| 38 |
+
scenarios early in training. ``mode="static_order"`` iterates the
|
| 39 |
+
list round-robin (handy for reproducing eval-style runs).
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
scenarios: tuple[str, ...] = ALL_SCENARIOS
|
| 43 |
+
mode: Literal["uniform", "weighted", "static_order"] = "uniform"
|
| 44 |
+
weights: tuple[float, ...] | None = None
|
| 45 |
+
seed_range: tuple[int, int] = (0, 2**31 - 1)
|
| 46 |
+
|
| 47 |
+
def __post_init__(self) -> None:
|
| 48 |
+
if not self.scenarios:
|
| 49 |
+
raise ValueError("CurriculumConfig.scenarios must be non-empty")
|
| 50 |
+
if self.mode == "weighted":
|
| 51 |
+
if self.weights is None or len(self.weights) != len(self.scenarios):
|
| 52 |
+
raise ValueError("mode='weighted' requires weights of the same length as scenarios")
|
| 53 |
+
if any(w < 0 for w in self.weights):
|
| 54 |
+
raise ValueError("weights must all be >= 0")
|
| 55 |
+
if sum(self.weights) <= 0:
|
| 56 |
+
raise ValueError("at least one weight must be > 0")
|
| 57 |
+
lo, hi = self.seed_range
|
| 58 |
+
if lo < 0 or hi <= lo:
|
| 59 |
+
raise ValueError("seed_range must be (lo >= 0, hi > lo)")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@dataclass(frozen=True)
|
| 63 |
+
class GRPOConfig:
|
| 64 |
+
"""Top-level training config for the GRPO skeleton.
|
| 65 |
+
|
| 66 |
+
Defaults: Qwen3-1.7B (Unsloth bnb-4bit) + LoRA r=16, group size 4,
|
| 67 |
+
500 max steps — but every knob is override-able from the CLI or a
|
| 68 |
+
JSON/YAML manifest. Model swap from Qwen2.5-1.5B → Qwen3-1.7B per
|
| 69 |
+
design/RL_HACKATHON_AUDIT_AND_RUNBOOK.md (section 4: newer family with
|
| 70 |
+
documented agentic tool-calling, comparable footprint).
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
model_name: str = "unsloth/Qwen3-1.7B-unsloth-bnb-4bit"
|
| 74 |
+
max_seq_length: int = 4096
|
| 75 |
+
load_in_4bit: bool = True
|
| 76 |
+
|
| 77 |
+
lora_r: int = 16
|
| 78 |
+
lora_alpha: int = 16
|
| 79 |
+
lora_dropout: float = 0.0
|
| 80 |
+
lora_target_modules: tuple[str, ...] = (
|
| 81 |
+
"q_proj",
|
| 82 |
+
"k_proj",
|
| 83 |
+
"v_proj",
|
| 84 |
+
"o_proj",
|
| 85 |
+
"gate_proj",
|
| 86 |
+
"up_proj",
|
| 87 |
+
"down_proj",
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
# GRPO knobs
|
| 91 |
+
group_size: int = 4
|
| 92 |
+
learning_rate: float = 5e-6
|
| 93 |
+
max_steps: int = 500
|
| 94 |
+
gradient_accumulation_steps: int = 1
|
| 95 |
+
warmup_steps: int = 10
|
| 96 |
+
max_prompt_length: int = 2048
|
| 97 |
+
max_completion_length: int = 1024
|
| 98 |
+
temperature: float = 0.8
|
| 99 |
+
top_p: float = 0.95
|
| 100 |
+
seed: int = 0
|
| 101 |
+
|
| 102 |
+
# Env wiring
|
| 103 |
+
env_base_url: str = env_str("SQL_DRIFT_ENV_URL", "http://localhost:8000")
|
| 104 |
+
episode_step_budget: int = 25
|
| 105 |
+
dba_oracle_enabled: bool = False
|
| 106 |
+
|
| 107 |
+
# IO
|
| 108 |
+
output_dir: str = "outputs/grpo_run"
|
| 109 |
+
logging_steps: int = 1
|
| 110 |
+
save_steps: int = 100
|
| 111 |
+
|
| 112 |
+
curriculum: CurriculumConfig = field(default_factory=CurriculumConfig)
|
| 113 |
+
|
| 114 |
+
def __post_init__(self) -> None:
|
| 115 |
+
if self.group_size < 2:
|
| 116 |
+
raise ValueError("GRPO group_size must be >= 2 (GRPO requires groups).")
|
| 117 |
+
if self.max_steps < 1:
|
| 118 |
+
raise ValueError("max_steps must be >= 1")
|
| 119 |
+
if self.seed < 0:
|
| 120 |
+
raise ValueError("seed must be >= 0")
|
| 121 |
+
if self.lora_r < 1:
|
| 122 |
+
raise ValueError("lora_r must be >= 1")
|
| 123 |
+
if not 0.0 < self.temperature <= 2.0:
|
| 124 |
+
raise ValueError("temperature must be in (0, 2]")
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
__all__ = ["ALL_SCENARIOS", "CurriculumConfig", "GRPOConfig"]
|
training/eval.py
ADDED
|
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluation harness for SQLDrift.
|
| 2 |
+
|
| 3 |
+
K-rollouts-per-scenario sweep that emits a markdown ``report.md`` and a
|
| 4 |
+
row-level ``per_episode.csv`` so reviewers can eyeball pass rate and
|
| 5 |
+
reward distribution without spinning up a notebook.
|
| 6 |
+
|
| 7 |
+
Invocation::
|
| 8 |
+
|
| 9 |
+
python -m training.eval \
|
| 10 |
+
--checkpoint base \
|
| 11 |
+
--scenarios 1-10 \
|
| 12 |
+
--seeds-per-scenario 5 \
|
| 13 |
+
--out outputs/evals/<run_id>/
|
| 14 |
+
|
| 15 |
+
``--checkpoint base`` runs the CPU :class:`RandomAgent` as a baseline
|
| 16 |
+
(no LLM, no GPU). A non-"base" value is a pointer to a saved adapter
|
| 17 |
+
and will attempt the lazy Unsloth import path in :func:`load_agent`.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import csv
|
| 24 |
+
import json
|
| 25 |
+
import math
|
| 26 |
+
import statistics
|
| 27 |
+
import time
|
| 28 |
+
from collections.abc import Callable
|
| 29 |
+
from dataclasses import dataclass, field
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
from typing import Any, Protocol
|
| 32 |
+
|
| 33 |
+
from models import REWARD_COMPONENT_KEYS, SqlDriftAction, SqlDriftObservation
|
| 34 |
+
from server import SqlDriftEnvironment
|
| 35 |
+
from training.config import ALL_SCENARIOS
|
| 36 |
+
from training.random_agent import RandomAgent
|
| 37 |
+
from training.seeding import set_seed
|
| 38 |
+
from utilities.logger import get_module_logger
|
| 39 |
+
|
| 40 |
+
_LOG = get_module_logger(__name__)
|
| 41 |
+
|
| 42 |
+
PASS_REWARD_THRESHOLD = 0.5
|
| 43 |
+
"""An episode is a "pass" when its terminal reward (the step that flips
|
| 44 |
+
``obs.done``—the submit step or the last step before budget exhaustion)
|
| 45 |
+
meets this threshold. The rubric scores a correct submission at
|
| 46 |
+
``+1.0`` in :attr:`r_correct`, so ``0.5`` matches the
|
| 47 |
+
"correct-but-not-improved" knee. Per-step values accumulate in
|
| 48 |
+
``episode_return`` and are not used for the pass bit."""
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# -----------------------------------------------------------------------------
|
| 52 |
+
# Agent interface
|
| 53 |
+
# -----------------------------------------------------------------------------
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Agent(Protocol):
|
| 57 |
+
"""Duck-typed policy — :meth:`RandomAgent.act` fits this shape."""
|
| 58 |
+
|
| 59 |
+
def reset(self, seed: int | None = None, scenario_id: str | None = None) -> None: ...
|
| 60 |
+
def act(self, obs: SqlDriftObservation) -> SqlDriftAction: ...
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def load_agent(
|
| 64 |
+
checkpoint: str,
|
| 65 |
+
*,
|
| 66 |
+
seed: int = 0,
|
| 67 |
+
base_model: str | None = None,
|
| 68 |
+
temperature: float = 0.0,
|
| 69 |
+
) -> Agent:
|
| 70 |
+
"""Resolve a checkpoint spec to a concrete agent.
|
| 71 |
+
|
| 72 |
+
* ``base``/``random`` — CPU-only :class:`RandomAgent` baseline.
|
| 73 |
+
* Any other value is treated as a filesystem path (a full
|
| 74 |
+
Hugging Face checkpoint directory or a PEFT adapter directory). The
|
| 75 |
+
:class:`training.llm_agent.LLMAgent` is imported lazily so
|
| 76 |
+
CPU-only CI that never calls ``load_agent`` with a path never
|
| 77 |
+
has to install ``transformers``/``peft``.
|
| 78 |
+
|
| 79 |
+
``base_model`` is forwarded to :class:`LLMAgent` when the adapter
|
| 80 |
+
directory does not pin its base model; ``temperature=0`` (greedy)
|
| 81 |
+
is the default for deterministic eval sweeps.
|
| 82 |
+
"""
|
| 83 |
+
if checkpoint in ("base", "random"):
|
| 84 |
+
return RandomAgent(seed=seed)
|
| 85 |
+
|
| 86 |
+
path = Path(checkpoint)
|
| 87 |
+
if not path.exists():
|
| 88 |
+
raise FileNotFoundError(
|
| 89 |
+
f"checkpoint {checkpoint!r} is not 'base'/'random' and does not exist on disk"
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# Lazy import: keeps the heavy transformers/peft tree out of
|
| 93 |
+
# the default import surface for ``training.eval``.
|
| 94 |
+
from training.llm_agent import LLMAgent
|
| 95 |
+
|
| 96 |
+
return LLMAgent(
|
| 97 |
+
str(path),
|
| 98 |
+
base_model=base_model,
|
| 99 |
+
temperature=temperature,
|
| 100 |
+
seed=seed,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# -----------------------------------------------------------------------------
|
| 105 |
+
# Episode runner
|
| 106 |
+
# -----------------------------------------------------------------------------
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@dataclass
|
| 110 |
+
class EpisodeResult:
|
| 111 |
+
scenario_id: str
|
| 112 |
+
seed: int
|
| 113 |
+
terminal_reward: float
|
| 114 |
+
episode_return: float
|
| 115 |
+
steps: int
|
| 116 |
+
passed: bool
|
| 117 |
+
submitted: bool
|
| 118 |
+
drift_fired: bool
|
| 119 |
+
wall_ms: float
|
| 120 |
+
reward_components: dict[str, float] = field(default_factory=dict)
|
| 121 |
+
effective_speedup: float | None = None
|
| 122 |
+
|
| 123 |
+
def as_row(self) -> dict[str, str]:
|
| 124 |
+
row: dict[str, str] = {
|
| 125 |
+
"scenario_id": self.scenario_id,
|
| 126 |
+
"seed": str(self.seed),
|
| 127 |
+
"terminal_reward": f"{self.terminal_reward:.4f}",
|
| 128 |
+
"episode_return": f"{self.episode_return:.4f}",
|
| 129 |
+
"steps": str(self.steps),
|
| 130 |
+
"passed": "1" if self.passed else "0",
|
| 131 |
+
"submitted": "1" if self.submitted else "0",
|
| 132 |
+
"drift_fired": "1" if self.drift_fired else "0",
|
| 133 |
+
"wall_ms": f"{self.wall_ms:.2f}",
|
| 134 |
+
"effective_speedup": _format_speedup(self.effective_speedup),
|
| 135 |
+
}
|
| 136 |
+
for k in REWARD_COMPONENT_KEYS:
|
| 137 |
+
row[k] = f"{self.reward_components.get(k, 0.0):.4f}"
|
| 138 |
+
return row
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _format_speedup(value: float | None) -> str:
|
| 142 |
+
"""Render an effective_speedup cell for CSV output.
|
| 143 |
+
|
| 144 |
+
``None`` (no submission) and ``+∞`` (baseline invalidated by drift)
|
| 145 |
+
need distinct, non-numeric representations so a downstream parser
|
| 146 |
+
cannot conflate "no data" with "infinite" — both read back as empty
|
| 147 |
+
cells today, which would bias per-scenario means.
|
| 148 |
+
"""
|
| 149 |
+
if value is None:
|
| 150 |
+
return ""
|
| 151 |
+
if math.isinf(value):
|
| 152 |
+
return "inf"
|
| 153 |
+
return f"{value:.3f}"
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _effective_speedup(env: SqlDriftEnvironment) -> float | None:
|
| 157 |
+
"""Read the current episode's effective speedup through the env surface.
|
| 158 |
+
|
| 159 |
+
Kept as a thin shim so tests and reporting code don't need to reach
|
| 160 |
+
into env internals themselves.
|
| 161 |
+
"""
|
| 162 |
+
return env.effective_speedup()
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _run_one_episode(
|
| 166 |
+
env: SqlDriftEnvironment,
|
| 167 |
+
agent: Agent,
|
| 168 |
+
*,
|
| 169 |
+
scenario_id: str,
|
| 170 |
+
seed: int,
|
| 171 |
+
max_steps: int = 30,
|
| 172 |
+
) -> EpisodeResult:
|
| 173 |
+
t0 = time.perf_counter()
|
| 174 |
+
obs = env.reset(seed=seed, scenario_id=scenario_id)
|
| 175 |
+
agent.reset(seed=seed, scenario_id=scenario_id)
|
| 176 |
+
|
| 177 |
+
episode_return = 0.0
|
| 178 |
+
terminal_reward = 0.0
|
| 179 |
+
last_components: dict[str, float] = {}
|
| 180 |
+
steps = 0
|
| 181 |
+
while not obs.done and steps < max_steps:
|
| 182 |
+
action = agent.act(obs)
|
| 183 |
+
obs = env.step(action)
|
| 184 |
+
if obs.reward is not None:
|
| 185 |
+
episode_return += obs.reward
|
| 186 |
+
terminal_reward = obs.reward
|
| 187 |
+
if obs.reward_components:
|
| 188 |
+
last_components = dict(obs.reward_components)
|
| 189 |
+
steps += 1
|
| 190 |
+
|
| 191 |
+
state = env.state
|
| 192 |
+
wall_ms = (time.perf_counter() - t0) * 1000.0
|
| 193 |
+
|
| 194 |
+
return EpisodeResult(
|
| 195 |
+
scenario_id=scenario_id,
|
| 196 |
+
seed=seed,
|
| 197 |
+
terminal_reward=terminal_reward,
|
| 198 |
+
episode_return=episode_return,
|
| 199 |
+
steps=steps,
|
| 200 |
+
passed=terminal_reward >= PASS_REWARD_THRESHOLD,
|
| 201 |
+
submitted=state.submitted,
|
| 202 |
+
drift_fired=state.drift_fired,
|
| 203 |
+
wall_ms=wall_ms,
|
| 204 |
+
reward_components=last_components,
|
| 205 |
+
effective_speedup=_effective_speedup(env),
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def run_eval(
|
| 210 |
+
*,
|
| 211 |
+
checkpoint: str,
|
| 212 |
+
scenarios: list[str],
|
| 213 |
+
seeds_per_scenario: int,
|
| 214 |
+
out_dir: Path,
|
| 215 |
+
max_steps: int = 30,
|
| 216 |
+
base_seed: int = 0,
|
| 217 |
+
progress_cb: Callable[[int, int], None] | None = None,
|
| 218 |
+
base_model: str | None = None,
|
| 219 |
+
temperature: float = 0.0,
|
| 220 |
+
) -> dict[str, Any]:
|
| 221 |
+
"""Drive the full sweep. Returns the summary dict also written to JSON.
|
| 222 |
+
|
| 223 |
+
``base_model`` / ``temperature`` are forwarded to
|
| 224 |
+
:func:`load_agent` for LLM-checkpoint runs (ignored for
|
| 225 |
+
``base``/``random``).
|
| 226 |
+
"""
|
| 227 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 228 |
+
set_seed(base_seed)
|
| 229 |
+
agent = load_agent(
|
| 230 |
+
checkpoint,
|
| 231 |
+
seed=base_seed,
|
| 232 |
+
base_model=base_model,
|
| 233 |
+
temperature=temperature,
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
results: list[EpisodeResult] = []
|
| 237 |
+
env = SqlDriftEnvironment()
|
| 238 |
+
total = len(scenarios) * seeds_per_scenario
|
| 239 |
+
done = 0
|
| 240 |
+
try:
|
| 241 |
+
for scenario_id in scenarios:
|
| 242 |
+
for k in range(seeds_per_scenario):
|
| 243 |
+
seed = base_seed + k
|
| 244 |
+
res = _run_one_episode(
|
| 245 |
+
env,
|
| 246 |
+
agent,
|
| 247 |
+
scenario_id=scenario_id,
|
| 248 |
+
seed=seed,
|
| 249 |
+
max_steps=max_steps,
|
| 250 |
+
)
|
| 251 |
+
results.append(res)
|
| 252 |
+
done += 1
|
| 253 |
+
if progress_cb:
|
| 254 |
+
progress_cb(done, total)
|
| 255 |
+
finally:
|
| 256 |
+
env.close()
|
| 257 |
+
|
| 258 |
+
_write_per_episode_csv(results, out_dir / "per_episode.csv")
|
| 259 |
+
summary = _build_summary(
|
| 260 |
+
results,
|
| 261 |
+
checkpoint=checkpoint,
|
| 262 |
+
scenarios=scenarios,
|
| 263 |
+
seeds_per_scenario=seeds_per_scenario,
|
| 264 |
+
)
|
| 265 |
+
(out_dir / "summary.json").write_text(json.dumps(summary, indent=2))
|
| 266 |
+
(out_dir / "report.md").write_text(render_report(summary, results))
|
| 267 |
+
return summary
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
# -----------------------------------------------------------------------------
|
| 271 |
+
# Reporting
|
| 272 |
+
# -----------------------------------------------------------------------------
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def _write_per_episode_csv(results: list[EpisodeResult], path: Path) -> None:
|
| 276 |
+
if not results:
|
| 277 |
+
path.write_text("")
|
| 278 |
+
return
|
| 279 |
+
fieldnames = list(results[0].as_row().keys())
|
| 280 |
+
with path.open("w", newline="") as f:
|
| 281 |
+
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
| 282 |
+
writer.writeheader()
|
| 283 |
+
for r in results:
|
| 284 |
+
writer.writerow(r.as_row())
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def _speedup_stats(
|
| 288 |
+
speedups: list[float],
|
| 289 |
+
) -> tuple[float | None, int]:
|
| 290 |
+
"""Finite-mean + ``+∞`` count for a list of effective_speedup values.
|
| 291 |
+
|
| 292 |
+
``statistics.fmean`` on a list containing ``+∞`` returns ``+∞``,
|
| 293 |
+
which poisons JSON output (``json.dumps`` rejects ``Infinity`` by
|
| 294 |
+
default) and washes out the finite distribution. We split the two
|
| 295 |
+
channels: the reported mean is over *finite* speedups only, and the
|
| 296 |
+
``infinite_count`` is surfaced separately so reviewers can see how
|
| 297 |
+
often drift invalidated the baseline.
|
| 298 |
+
"""
|
| 299 |
+
finite = [s for s in speedups if not math.isinf(s)]
|
| 300 |
+
infinite_count = len(speedups) - len(finite)
|
| 301 |
+
mean = statistics.fmean(finite) if finite else None
|
| 302 |
+
return mean, infinite_count
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def _build_summary(
|
| 306 |
+
results: list[EpisodeResult],
|
| 307 |
+
*,
|
| 308 |
+
checkpoint: str,
|
| 309 |
+
scenarios: list[str],
|
| 310 |
+
seeds_per_scenario: int,
|
| 311 |
+
) -> dict[str, Any]:
|
| 312 |
+
by_scenario: dict[str, dict[str, Any]] = {}
|
| 313 |
+
for sid in scenarios:
|
| 314 |
+
scoped = [r for r in results if r.scenario_id == sid]
|
| 315 |
+
if not scoped:
|
| 316 |
+
continue
|
| 317 |
+
terminals = [r.terminal_reward for r in scoped]
|
| 318 |
+
returns = [r.episode_return for r in scoped]
|
| 319 |
+
speedups = [r.effective_speedup for r in scoped if r.effective_speedup is not None]
|
| 320 |
+
mean_sp, inf_sp = _speedup_stats(speedups)
|
| 321 |
+
by_scenario[sid] = {
|
| 322 |
+
"n": len(scoped),
|
| 323 |
+
"pass_rate": sum(1 for r in scoped if r.passed) / len(scoped),
|
| 324 |
+
"mean_terminal_reward": statistics.fmean(terminals),
|
| 325 |
+
"std_terminal_reward": statistics.pstdev(terminals) if len(terminals) > 1 else 0.0,
|
| 326 |
+
"mean_episode_return": statistics.fmean(returns),
|
| 327 |
+
"submit_rate": sum(1 for r in scoped if r.submitted) / len(scoped),
|
| 328 |
+
"mean_effective_speedup": mean_sp,
|
| 329 |
+
"infinite_speedup_count": inf_sp,
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
all_speedups = [r.effective_speedup for r in results if r.effective_speedup is not None]
|
| 333 |
+
mean_sp_all, inf_sp_all = _speedup_stats(all_speedups)
|
| 334 |
+
overall = {
|
| 335 |
+
"checkpoint": checkpoint,
|
| 336 |
+
"n_episodes": len(results),
|
| 337 |
+
"seeds_per_scenario": seeds_per_scenario,
|
| 338 |
+
"pass_rate": sum(1 for r in results if r.passed) / len(results) if results else 0.0,
|
| 339 |
+
"mean_terminal_reward": (
|
| 340 |
+
statistics.fmean(r.terminal_reward for r in results) if results else 0.0
|
| 341 |
+
),
|
| 342 |
+
"std_terminal_reward": (
|
| 343 |
+
statistics.pstdev([r.terminal_reward for r in results]) if len(results) > 1 else 0.0
|
| 344 |
+
),
|
| 345 |
+
"mean_episode_return": (
|
| 346 |
+
statistics.fmean(r.episode_return for r in results) if results else 0.0
|
| 347 |
+
),
|
| 348 |
+
"submit_rate": sum(1 for r in results if r.submitted) / len(results) if results else 0.0,
|
| 349 |
+
"mean_effective_speedup": mean_sp_all,
|
| 350 |
+
"infinite_speedup_count": inf_sp_all,
|
| 351 |
+
}
|
| 352 |
+
return {"overall": overall, "by_scenario": by_scenario}
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def render_report(
|
| 356 |
+
summary: dict[str, Any],
|
| 357 |
+
results: list[EpisodeResult],
|
| 358 |
+
) -> str:
|
| 359 |
+
"""Compose a reviewer-friendly ``report.md`` string."""
|
| 360 |
+
overall = summary["overall"]
|
| 361 |
+
lines: list[str] = []
|
| 362 |
+
lines.append("# SQLDrift evaluation report")
|
| 363 |
+
lines.append("")
|
| 364 |
+
speedup = overall.get("mean_effective_speedup")
|
| 365 |
+
speedup_cell = f"{speedup:.2f}x" if speedup is not None else "—"
|
| 366 |
+
lines.append(f"- Checkpoint: `{overall['checkpoint']}`")
|
| 367 |
+
lines.append(f"- Episodes: **{overall['n_episodes']}**")
|
| 368 |
+
lines.append(f"- Seeds/scenario: {overall['seeds_per_scenario']}")
|
| 369 |
+
lines.append(
|
| 370 |
+
f"- Overall pass rate (terminal reward ≥ {PASS_REWARD_THRESHOLD}): "
|
| 371 |
+
f"**{overall['pass_rate']:.1%}**"
|
| 372 |
+
)
|
| 373 |
+
lines.append(
|
| 374 |
+
f"- Mean terminal reward: **{overall['mean_terminal_reward']:.3f}** "
|
| 375 |
+
f"(σ = {overall['std_terminal_reward']:.3f})"
|
| 376 |
+
)
|
| 377 |
+
lines.append(f"- Mean episode return: {overall['mean_episode_return']:.3f}")
|
| 378 |
+
lines.append(f"- Submit rate: {overall['submit_rate']:.1%}")
|
| 379 |
+
lines.append(f"- Mean effective speedup (finite, submitted episodes): **{speedup_cell}**")
|
| 380 |
+
inf_count = overall.get("infinite_speedup_count", 0)
|
| 381 |
+
if inf_count:
|
| 382 |
+
lines.append(f"- Infinite-speedup episodes (drift invalidated baseline): **{inf_count}**")
|
| 383 |
+
lines.append("")
|
| 384 |
+
|
| 385 |
+
lines.append("## Per-scenario")
|
| 386 |
+
lines.append("")
|
| 387 |
+
lines.append("| Scenario | N | Pass | Terminal μ | Return μ | Submit | Speedup |")
|
| 388 |
+
lines.append("|----------|---|------|-----------|----------|--------|---------|")
|
| 389 |
+
for sid, row in summary["by_scenario"].items():
|
| 390 |
+
sp = row.get("mean_effective_speedup")
|
| 391 |
+
sp_cell = f"{sp:.2f}x" if sp is not None else "—"
|
| 392 |
+
lines.append(
|
| 393 |
+
f"| `{sid}` | {row['n']} | {row['pass_rate']:.0%} | "
|
| 394 |
+
f"{row['mean_terminal_reward']:.3f} | {row['mean_episode_return']:.3f} | "
|
| 395 |
+
f"{row['submit_rate']:.0%} | {sp_cell} |"
|
| 396 |
+
)
|
| 397 |
+
lines.append("")
|
| 398 |
+
|
| 399 |
+
lines.append("## Reward-component bars (mean across episodes)")
|
| 400 |
+
lines.append("")
|
| 401 |
+
lines.append(_render_component_bars(results))
|
| 402 |
+
lines.append("")
|
| 403 |
+
|
| 404 |
+
lines.append("## Notes")
|
| 405 |
+
lines.append("")
|
| 406 |
+
lines.append(
|
| 407 |
+
"- Random baseline does NOT attempt rewrites intelligently; "
|
| 408 |
+
"nonzero pass rate here is a lower bound on a well-trained agent."
|
| 409 |
+
)
|
| 410 |
+
lines.append(
|
| 411 |
+
"- Pass threshold is set at reward ≥ "
|
| 412 |
+
f"{PASS_REWARD_THRESHOLD}, matching the rubric's "
|
| 413 |
+
'"correct-but-not-improved" +0.5 partial credit (rubric).'
|
| 414 |
+
)
|
| 415 |
+
return "\n".join(lines) + "\n"
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def _render_component_bars(results: list[EpisodeResult]) -> str:
|
| 419 |
+
if not results:
|
| 420 |
+
return "_no data_"
|
| 421 |
+
sums: dict[str, float] = {k: 0.0 for k in REWARD_COMPONENT_KEYS}
|
| 422 |
+
counts: dict[str, int] = {k: 0 for k in REWARD_COMPONENT_KEYS}
|
| 423 |
+
for r in results:
|
| 424 |
+
for k, v in r.reward_components.items():
|
| 425 |
+
if k in sums:
|
| 426 |
+
sums[k] += v
|
| 427 |
+
counts[k] += 1
|
| 428 |
+
means = {k: (sums[k] / counts[k] if counts[k] else 0.0) for k in sums}
|
| 429 |
+
|
| 430 |
+
# ASCII bar with sign (+/-).
|
| 431 |
+
max_abs = max((abs(v) for v in means.values()), default=1.0) or 1.0
|
| 432 |
+
width = 30
|
| 433 |
+
lines = ["```"]
|
| 434 |
+
for k in REWARD_COMPONENT_KEYS:
|
| 435 |
+
v = means[k]
|
| 436 |
+
bar_len = int(round(abs(v) / max_abs * width))
|
| 437 |
+
bar = ("█" * bar_len) if v >= 0 else ("▒" * bar_len)
|
| 438 |
+
lines.append(f"{k:<14} {v:+7.3f} {bar}")
|
| 439 |
+
lines.append("```")
|
| 440 |
+
return "\n".join(lines)
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
# -----------------------------------------------------------------------------
|
| 444 |
+
# CLI
|
| 445 |
+
# -----------------------------------------------------------------------------
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def _expand_scenarios(spec: str) -> list[str]:
|
| 449 |
+
"""Accept either ``1-10``, ``1,3,5``, or a comma-list of raw ids."""
|
| 450 |
+
if "-" in spec and all(part.isdigit() for part in spec.split("-")):
|
| 451 |
+
lo, hi = (int(x) for x in spec.split("-"))
|
| 452 |
+
want_range: set[int] = set(range(lo, hi + 1))
|
| 453 |
+
return [s for s in ALL_SCENARIOS if int(s.split("_", 1)[0]) in want_range]
|
| 454 |
+
if all(part.strip().isdigit() for part in spec.split(",")):
|
| 455 |
+
want_set = {int(part) for part in spec.split(",")}
|
| 456 |
+
return [s for s in ALL_SCENARIOS if int(s.split("_", 1)[0]) in want_set]
|
| 457 |
+
return [s.strip() for s in spec.split(",")]
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def main(argv: list[str] | None = None) -> None:
|
| 461 |
+
ap = argparse.ArgumentParser(description="Evaluate an agent on SQLDrift.")
|
| 462 |
+
ap.add_argument("--checkpoint", required=True, help="'base' or adapter path")
|
| 463 |
+
ap.add_argument(
|
| 464 |
+
"--scenarios",
|
| 465 |
+
default="1-10",
|
| 466 |
+
help="Scenario range/list (e.g. '1-10', '1,3,5') or raw ids",
|
| 467 |
+
)
|
| 468 |
+
ap.add_argument("--seeds-per-scenario", type=int, default=5)
|
| 469 |
+
ap.add_argument("--out", type=Path, required=True)
|
| 470 |
+
ap.add_argument("--max-steps", type=int, default=30)
|
| 471 |
+
ap.add_argument("--base-seed", type=int, default=0)
|
| 472 |
+
ap.add_argument(
|
| 473 |
+
"--base-model",
|
| 474 |
+
default=None,
|
| 475 |
+
help="Override the base model id when --checkpoint points at a PEFT adapter",
|
| 476 |
+
)
|
| 477 |
+
ap.add_argument(
|
| 478 |
+
"--temperature",
|
| 479 |
+
type=float,
|
| 480 |
+
default=0.0,
|
| 481 |
+
help="Decoding temperature for LLM checkpoints (0 = greedy / deterministic).",
|
| 482 |
+
)
|
| 483 |
+
ns = ap.parse_args(argv)
|
| 484 |
+
|
| 485 |
+
scenarios = _expand_scenarios(ns.scenarios)
|
| 486 |
+
if not scenarios:
|
| 487 |
+
raise SystemExit(f"no scenarios matched spec {ns.scenarios!r}")
|
| 488 |
+
|
| 489 |
+
def _prog(done: int, total: int) -> None:
|
| 490 |
+
if done == total or done % max(1, total // 10) == 0:
|
| 491 |
+
_LOG.info("eval: %d/%d episodes", done, total)
|
| 492 |
+
|
| 493 |
+
summary = run_eval(
|
| 494 |
+
checkpoint=ns.checkpoint,
|
| 495 |
+
scenarios=scenarios,
|
| 496 |
+
seeds_per_scenario=ns.seeds_per_scenario,
|
| 497 |
+
out_dir=ns.out,
|
| 498 |
+
max_steps=ns.max_steps,
|
| 499 |
+
base_seed=ns.base_seed,
|
| 500 |
+
progress_cb=_prog,
|
| 501 |
+
base_model=ns.base_model,
|
| 502 |
+
temperature=ns.temperature,
|
| 503 |
+
)
|
| 504 |
+
print(json.dumps(summary["overall"], indent=2))
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
if __name__ == "__main__":
|
| 508 |
+
main()
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
__all__ = [
|
| 512 |
+
"Agent",
|
| 513 |
+
"EpisodeResult",
|
| 514 |
+
"PASS_REWARD_THRESHOLD",
|
| 515 |
+
"load_agent",
|
| 516 |
+
"main",
|
| 517 |
+
"render_report",
|
| 518 |
+
"run_eval",
|
| 519 |
+
]
|
training/evidence/grpo_components_combined.png
ADDED
|
Git LFS Details
|
training/evidence/grpo_loss_curve.png
ADDED
|
training/evidence/grpo_metrics.csv
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
step,loss,reward,r_correct,r_drift,r_step_tax
|
| 2 |
+
1,1.5070857702340332,-1.089657810470115,0.0060142775960789485,-0.05147482162696701,-0.5153269778837614
|
| 3 |
+
2,1.4163939988607464,-0.8820803512725144,0.03844503923645059,-0.10545186822296794,-0.4861222451750179
|
| 4 |
+
3,1.4296624851767596,-0.9146981630547505,0.10331660081021651,0.07228358802913387,-0.4961777953169654
|
| 5 |
+
4,1.3142921668473637,-0.5595880667701431,0.08344587035631848,0.059823292014677544,-0.520072970343028
|
| 6 |
+
5,1.2701518625957149,-0.5996319117518063,0.2610718206538336,0.004943992480230808,-0.4840739746024923
|
| 7 |
+
6,1.2640116521205307,-0.8258103297724467,0.26761100025304996,0.1913788646035109,-0.49710482987102533
|
| 8 |
+
7,1.1209233460223358,-0.6211756254418923,0.27404383444577135,0.200216496595977,-0.4682072435811402
|
| 9 |
+
8,1.1774670143175638,-0.619087388960126,0.270469688082585,0.17365149920126927,-0.4864816492894405
|
| 10 |
+
9,1.1983245069119268,-0.5491623820497955,0.3851815028189252,0.14236176514396473,-0.47340387335667083
|
| 11 |
+
10,1.1900131680220651,-0.4531450217747221,0.46035015896826725,0.20935614488999046,-0.3977678494548791
|
| 12 |
+
11,1.0778750024182207,-0.5529804804536397,0.4784254548477095,0.25126495579288227,-0.4914229994698512
|
| 13 |
+
12,1.0142047885469356,-0.30198009788159785,0.5200582520058386,0.24302822958435058,-0.46425248430094396
|
| 14 |
+
13,1.0045646774653714,-0.4645827945838715,0.5862339331205229,0.27468662267016536,-0.4534963800694839
|
| 15 |
+
14,0.9116167040003166,-0.3810888813065016,0.5319641343757862,0.2831807948097964,-0.43758017046734105
|
| 16 |
+
15,0.9119484849618256,-0.19031109272031552,0.544487836273619,0.30242310302170133,-0.43358305018602905
|
| 17 |
+
16,0.8631872796395779,-0.1902270628270999,0.5378009235565281,0.46520136753468744,-0.4378038803670464
|
| 18 |
+
17,0.9108494430815297,-0.28522550174067374,0.6317914174443212,0.39742825226075135,-0.4194535022660978
|
| 19 |
+
18,0.7414210838734385,-0.026510480876009573,0.6298982943138216,0.49275288649025134,-0.4263594599419731
|
| 20 |
+
19,0.8025799827262751,-0.09087146161792733,0.69004717240886,0.4924557349195555,-0.3603965798497102
|
| 21 |
+
20,0.6973776508874913,0.011148493528077183,0.7800312956412032,0.456173588246084,-0.4150265290706628
|
training/evidence/grpo_r_correct_curve.png
ADDED
|
training/evidence/grpo_r_drift_curve.png
ADDED
|