visheshrathi commited on
Commit
5850885
·
verified ·
1 Parent(s): 413db20

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. Dockerfile +95 -0
  3. README.md +151 -6
  4. __init__.py +29 -0
  5. _cli.py +47 -0
  6. actors/__init__.py +1 -0
  7. actors/dba_oracle.py +183 -0
  8. actors/engineering_manager.py +81 -0
  9. client.py +121 -0
  10. engine/__init__.py +1 -0
  11. engine/drift.py +242 -0
  12. engine/profiler.py +298 -0
  13. engine/reward.py +437 -0
  14. engine/runtime.py +101 -0
  15. engine/verifier.py +101 -0
  16. models.py +364 -0
  17. openenv.yaml +7 -0
  18. pyproject.toml +163 -0
  19. scenarios/__init__.py +72 -0
  20. scenarios/_fixtures.py +119 -0
  21. scenarios/base.py +230 -0
  22. scenarios/scenario_01_correlated_subquery.py +93 -0
  23. scenarios/scenario_02_select_star_join.py +119 -0
  24. scenarios/scenario_03_cartesian_join.py +93 -0
  25. scenarios/scenario_04_distinct_groupby.py +73 -0
  26. scenarios/scenario_05_nested_subquery.py +100 -0
  27. scenarios/scenario_06_having_as_where.py +92 -0
  28. scenarios/scenario_07_drift_column_rename.py +73 -0
  29. scenarios/scenario_08_drift_date_format.py +93 -0
  30. scenarios/scenario_09_drift_enum_rule.py +85 -0
  31. scenarios/scenario_10_drift_field_deprecation.py +74 -0
  32. server/__init__.py +5 -0
  33. server/app.py +89 -0
  34. server/requirements.txt +116 -0
  35. server/settings.py +31 -0
  36. server/sql_drift_env_environment.py +899 -0
  37. skill_library/__init__.py +54 -0
  38. skill_library/entries.py +93 -0
  39. skill_library/retrieval.py +85 -0
  40. skill_library/seeds.py +126 -0
  41. skill_library/store.py +257 -0
  42. sql_drift_grpo_training.ipynb +373 -0
  43. training/__init__.py +27 -0
  44. training/config.py +127 -0
  45. training/eval.py +519 -0
  46. training/evidence/grpo_components_combined.png +3 -0
  47. training/evidence/grpo_loss_curve.png +0 -0
  48. training/evidence/grpo_metrics.csv +21 -0
  49. training/evidence/grpo_r_correct_curve.png +0 -0
  50. training/evidence/grpo_r_drift_curve.png +0 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ training/evidence/grpo_components_combined.png filter=lfs diff=lfs merge=lfs -text
37
+ training/evidence/grpo_reward_curve.png filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Multi-stage build using openenv-base
8
+ # This Dockerfile is flexible and works for both:
9
+ # - In-repo environments (with local OpenEnv sources)
10
+ # - Standalone environments (with openenv from PyPI/Git)
11
+ # The build script (openenv build) handles context detection and sets appropriate build args.
12
+
13
+ # ─── IMAGE PINNING ────────────────────────────────────────────────────────────
14
+ # Using :latest is a moving target: a base-image update between local build
15
+ # and judging rebuild can pull a different Python / CUDA / OS version.
16
+ #
17
+ # TO PIN (run once, commit the digest):
18
+ # docker pull ghcr.io/meta-pytorch/openenv-base:latest
19
+ # docker inspect --format='{{index .RepoDigests 0}}' ghcr.io/meta-pytorch/openenv-base:latest
20
+ # # → ghcr.io/meta-pytorch/openenv-base@sha256:<digest>
21
+ #
22
+ # Then replace the ARG default below with the digest, e.g.:
23
+ # ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base@sha256:<digest>
24
+ #
25
+ # The sha256 digest is immutable; the image will not change between builds.
26
+ # ──────────────────────────────────────────────────────────────────────────────
27
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
28
+ FROM ${BASE_IMAGE} AS builder
29
+
30
+ WORKDIR /app
31
+
32
+ # Ensure git is available (required for installing dependencies from VCS)
33
+ RUN apt-get update && \
34
+ apt-get install -y --no-install-recommends git && \
35
+ rm -rf /var/lib/apt/lists/*
36
+
37
+ # Build argument to control whether we're building standalone or in-repo
38
+ ARG BUILD_MODE=in-repo
39
+ ARG ENV_NAME=sql_drift_env
40
+
41
+ # Copy environment code (always at root of build context)
42
+ COPY . /app/env
43
+
44
+ # For in-repo builds, openenv is already vendored in the build context
45
+ # For standalone builds, openenv will be installed via pyproject.toml
46
+ WORKDIR /app/env
47
+
48
+ # Ensure uv is available (for local builds where base image lacks it)
49
+ RUN if ! command -v uv >/dev/null 2>&1; then \
50
+ curl -LsSf https://astral.sh/uv/install.sh | sh && \
51
+ mv /root/.local/bin/uv /usr/local/bin/uv && \
52
+ mv /root/.local/bin/uvx /usr/local/bin/uvx; \
53
+ fi
54
+
55
+ # Install dependencies from the frozen lock file. Failing fast when uv.lock
56
+ # is absent prevents a silent re-resolution that could pull different versions
57
+ # than the ones tested (see design/codereview.md, dependency pins).
58
+ RUN --mount=type=cache,target=/root/.cache/uv \
59
+ uv sync --frozen --no-install-project --no-editable
60
+
61
+ RUN --mount=type=cache,target=/root/.cache/uv \
62
+ uv sync --frozen --no-editable
63
+
64
+ # Final runtime stage — inherits the same pinned/latest image chosen above.
65
+ FROM ${BASE_IMAGE}
66
+
67
+ WORKDIR /app
68
+
69
+ # Copy the uv-managed Python interpreter from the builder so that the
70
+ # shebangs baked into ``/app/env/.venv/bin/*`` (e.g. uvicorn) resolve.
71
+ # Without this, execve() on those scripts silently falls back to the
72
+ # system ``/usr/local/bin/uvicorn`` (Python 3.11) and loads the wrong
73
+ # site-packages, yielding ``ModuleNotFoundError: duckdb`` at startup.
74
+ COPY --from=builder /root/.local/share/uv /root/.local/share/uv
75
+
76
+ # Copy the environment code (includes the builder's .venv at /app/env/.venv)
77
+ COPY --from=builder /app/env /app/env
78
+
79
+ # Point PATH + PYTHONPATH at the venv that shipped with the env so
80
+ # both ``uvicorn ...`` and ``python -m uvicorn ...`` resolve against
81
+ # the right interpreter.
82
+ ENV PATH="/app/env/.venv/bin:$PATH"
83
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
84
+
85
+ # Health check (hits the stateless /health endpoint exposed by the
86
+ # OpenEnv FastAPI factory).
87
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
88
+ CMD curl -f http://localhost:8000/health || exit 1
89
+
90
+ # Run the FastAPI server via the venv's Python interpreter directly
91
+ # (bypassing any PATH-lookup ambiguity with the base image's system
92
+ # uvicorn).
93
+ ENV ENABLE_WEB_INTERFACE=true
94
+ CMD ["/app/env/.venv/bin/python", "-m", "uvicorn", "server.app:app", \
95
+ "--host", "0.0.0.0", "--port", "8000", "--app-dir", "/app/env"]
README.md CHANGED
@@ -1,10 +1,155 @@
1
  ---
2
- title: Sql Drift Env
3
- emoji: 📚
4
- colorFrom: pink
5
- colorTo: red
6
  sdk: docker
7
- pinned: false
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: SQLDrift
3
+ emoji: 🐘
4
+ colorFrom: blue
5
+ colorTo: indigo
6
  sdk: docker
7
+ app_port: 8000
8
+ short_description: SQL repair and optimization under live schema drift
9
+ base_path: /web
10
  ---
11
 
12
+ # SQLDrift
13
+
14
+ > An OpenEnv gym that teaches an LLM agent to **repair and optimize SQL** when
15
+ > the database schema and business rules shift out from under it.
16
+
17
+ SQLDrift is a production-grade [OpenEnv](https://github.com/meta-pytorch/OpenEnv)
18
+ environment designed for RL fine-tuning with TRL's
19
+ [`GRPOTrainer`](https://huggingface.co/docs/trl/openenv) (plus
20
+ [Unsloth](https://github.com/unslothai/unsloth) LoRA on consumer GPUs).
21
+ Each episode hands the agent a slow baseline SQL query, a lightly-populated
22
+ DuckDB database, a toolbox of read-only introspection / execution tools, and a
23
+ **25-step budget**. Midway through the episode a schema or business-rule
24
+ **drift event** may fire; the agent must read the deploy changelog, re-diagnose,
25
+ and submit a rewrite that is both (a) semantically correct and (b) ≥ 1.2× faster
26
+ than the baseline. A hierarchical rubric turns that into six reward signals
27
+ (correctness, drift-adapt, speedup, step-tax, gatekeepers, consult-dba).
28
+
29
+ Highlights:
30
+
31
+ - **10 hand-crafted scenarios** — 6 classic anti-patterns (correlated
32
+ subqueries, `SELECT *` joins, Cartesian joins, `DISTINCT` groupbys, nested
33
+ subqueries, `HAVING`-as-`WHERE`) and 4 drift events (column rename, date
34
+ format flip, enum rule split, field deprecation).
35
+ - **Deterministic fixtures** — 15–30-column schemas, 2–4 tables per scenario,
36
+ regenerated in-process at `reset()` from a seed. No Parquet files, no
37
+ pre-baked sqlite dumps.
38
+ - **Sqlglot-canonicalised baseline-verbatim gate** — agents that rename
39
+ whitespace and resubmit the baseline cannot farm the +1.0 correct bonus.
40
+ - **Skill library** — 12 pre-seeded playbook/drift-card entries plus an
41
+ on-disk, `fcntl.flock`-guarded JSON store that grows as the agent solves
42
+ episodes, surfaced next `reset()` via Jaccard retrieval.
43
+ - **Feature-flagged DBA Oracle** — 3-tier escalating hints per scenario,
44
+ penalised by the `ConsultDBA` rubric. Off by default.
45
+ - **Dockerised** — `server/Dockerfile` layers the env on
46
+ `ghcr.io/meta-pytorch/openenv-base` and exposes `/health`, `/reset`,
47
+ `/step`, `/ws` out-of-the-box.
48
+
49
+ ## Quick start
50
+
51
+ ```bash
52
+ git clone <your-fork>/sql_drift_env.git
53
+ cd sql_drift_env
54
+ uv sync # install deps
55
+ uv run pytest -q # 300+ tests, ~90s on CPU
56
+ uv run uvicorn server.app:app --reload # serve the env on :8000
57
+ ```
58
+
59
+ Or via Docker:
60
+
61
+ ```bash
62
+ docker build -f server/Dockerfile -t sql-drift-env:latest .
63
+ docker run -p 8000:8000 sql-drift-env:latest
64
+ curl -s http://localhost:8000/health
65
+ ```
66
+
67
+ ## Programmatic rollout
68
+
69
+ A five-step in-process episode against scenario `01_correlated_subquery`:
70
+
71
+ ```python
72
+ from client import SqlDriftEnv
73
+ from server import SqlDriftEnvironment
74
+
75
+ env = SqlDriftEnvironment()
76
+ obs = env.reset(seed=42, scenario_id="01_correlated_subquery")
77
+ print(obs.learned_hints) # surfaced from skill library
78
+
79
+ obs = env.step(SqlDriftEnv.action_list_tables())
80
+ obs = env.step(SqlDriftEnv.action_describe_table("users"))
81
+ obs = env.step(SqlDriftEnv.action_run_query("SELECT COUNT(*) FROM users"))
82
+ obs = env.step(SqlDriftEnv.action_submit_rewrite(
83
+ "SELECT u.*, COALESCE(c.n, 0) FROM users u "
84
+ "LEFT JOIN (SELECT user_id, COUNT(*) AS n FROM orders GROUP BY 1) c "
85
+ "ON c.user_id = u.user_id"
86
+ ))
87
+ print(obs.reward, obs.reward_components)
88
+ env.close()
89
+ ```
90
+
91
+ End-to-end over an HTTP+WS OpenEnv server, see `SqlDriftEnv` in
92
+ `client.py` and the integration test suite in
93
+ `tests/integration/test_client_server.py` /
94
+ `tests/integration/test_state_no_leak.py`.
95
+
96
+ ## Evaluation
97
+
98
+ ```bash
99
+ uv run python -m training.eval \
100
+ --checkpoint base \
101
+ --scenarios 1-10 \
102
+ --seeds-per-scenario 5 \
103
+ --out outputs/evals/my_run
104
+ ```
105
+
106
+ Emits `report.md`, `per_episode.csv`, and `summary.json`. A baseline random
107
+ agent sweep is checked in at `outputs/evals/baseline_random_v1/report.md` as a
108
+ floor: mean reward ≈ **-2.05**, pass rate **0 %**.
109
+
110
+ ## Training (GPU)
111
+
112
+ `training/grpo_train.py` contains the GRPO entrypoint used by the hackathon
113
+ training notebook: it builds the curriculum dataset, loads
114
+ `unsloth/Qwen3-1.7B-unsloth-bnb-4bit`, attaches PEFT LoRA adapters, and lets TRL's
115
+ `GRPOTrainer` drive multi-turn OpenEnv rollouts through `SqlDriftToolEnv`.
116
+ Run `training/sql_drift_grpo_training.ipynb` on a GPU runtime with
117
+ `SQL_DRIFT_ENV_URL` pointed at the deployed SQLDrift Space; the notebook writes
118
+ durable evidence plots to `training/evidence/`.
119
+
120
+ ```bash
121
+ uv sync --extra train # installs trl, transformers, accelerate, unsloth
122
+ uv sync --extra evidence # matplotlib + pandas for utilities/plot_curves.py
123
+ ```
124
+
125
+ ## Repository layout
126
+
127
+ ```
128
+ sql_drift_env/
129
+ ├── models.py # Pydantic v2 action/observation/state
130
+ ├── client.py # SqlDriftEnv EnvClient (/ws)
131
+ ├── engine/
132
+ │ ├── runtime.py # private RuntimeEpisodeState
133
+ │ ├── drift.py # 4 DDL drift operators
134
+ │ ├── reward.py # SqlDriftRubric (6 child rubrics)
135
+ │ ├── verifier.py / profiler.py
136
+ ├── scenarios/ # 10 hand-crafted scenario modules + registry
137
+ ├── skill_library/ # pre-seeds + JSON store + Jaccard retrieval
138
+ ├── actors/ # engineering_manager (changelog), dba_oracle
139
+ ├── server/ # FastAPI app, Dockerfile, env wrapper class
140
+ ├── training/ # config, prompt, random_agent, grpo_train, eval
141
+ ├── utilities/ # env_loader, logger, plot_curves, run_training_job (HF Jobs), …
142
+ ├── tests/ # 300+ unit + integration tests
143
+ ├── outputs/evals/ # baseline eval artifacts
144
+ └── design/ # design docs (ignored by docker)
145
+ ```
146
+
147
+ ## Design references
148
+
149
+ - `.cursor/plans/sqldrift_technical_spec_feb249d1.plan.md` — full 16-phase
150
+ technical spec (Rev-3).
151
+ - `design/CLAUDE.md` — agent coding guidelines.
152
+
153
+ ## License
154
+
155
+ BSD-3-Clause (matches the OpenEnv project template).
__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SQLDrift — OpenEnv gym for SQL repair & optimization under drift.
2
+
3
+ The repo uses a flat top-level module layout (``models.py``,
4
+ ``client.py``, ``engine/``, ``scenarios/``, ``skill_library/``,
5
+ ``actors/``, ``server/``, ``training/``) because it is also run as a
6
+ FastAPI server that imports siblings absolutely (``from models import
7
+ …``). setuptools republishes ``.`` as the ``sql_drift_env`` package so
8
+ both import styles work at runtime, but eagerly re-exporting the flat
9
+ submodules from here would shadow the top-level ``import models`` /
10
+ ``import client`` paths that every flat module relies on, and would
11
+ make the import order pytest-collection-sensitive.
12
+
13
+ The public API for agent code is therefore the flat modules themselves,
14
+ imported directly:
15
+
16
+ from client import SqlDriftEnv
17
+ from models import SqlDriftAction, SqlDriftObservation
18
+ from server import SqlDriftEnvironment
19
+
20
+ This mirrors the flat layout both on disk and at import time; the
21
+ ``sql_drift_env`` namespace exists only so the wheel has a canonical
22
+ name and so third parties can depend on a stable version string.
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ __version__ = "0.1.0"
28
+
29
+ __all__ = ["__version__"]
_cli.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Console-script entrypoint that patches sys.path for the flat-import layout.
2
+
3
+ Problem
4
+ -------
5
+ The project uses a *flat layout* (``pyproject.toml`` maps ``.`` → the
6
+ ``sql_drift_env`` package directory). That means every sibling module
7
+ (``models``, ``actors``, ``engine`` …) is imported as a plain top-level
8
+ name rather than via the ``sql_drift_env.`` prefix.
9
+
10
+ When the wheel is installed, those siblings land at
11
+ ``site-packages/sql_drift_env/models.py`` etc., *not* at the top-level
12
+ ``site-packages/`` directory. A naïve console-script that calls
13
+ ``sql_drift_env.server.app:main`` would fail at ``from models import …``
14
+ before reaching any application logic.
15
+
16
+ Fix
17
+ ---
18
+ Insert the installed package directory (``site-packages/sql_drift_env/``)
19
+ onto ``sys.path`` *before* importing anything from the server package.
20
+ This mirrors what Docker achieves via ``--app-dir /app/env`` / ``PYTHONPATH``,
21
+ but works for any installed-wheel invocation without requiring a wrapper
22
+ script or Docker.
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import os
28
+ import sys
29
+
30
+
31
+ def main() -> None:
32
+ # __file__ resolves to site-packages/sql_drift_env/_cli.py after
33
+ # installation, so its parent IS the directory that contains models.py,
34
+ # server/, actors/, etc.
35
+ _pkg_dir = os.path.dirname(os.path.abspath(__file__))
36
+ if _pkg_dir not in sys.path:
37
+ sys.path.insert(0, _pkg_dir)
38
+
39
+ # Import lazily so the sys.path fix takes effect before any flat import
40
+ # in server/app.py or its transitive dependencies is attempted.
41
+ from server.app import main as _server_main
42
+
43
+ _server_main()
44
+
45
+
46
+ if __name__ == "__main__":
47
+ main()
actors/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Non-agent actors (Engineering Manager, DBA Oracle) — P4/P9."""
actors/dba_oracle.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DBA Oracle: per-scenario escalating expert guidance.
2
+
3
+ Feature-flagged (off by default). When ``enable_dba_oracle=True`` is set
4
+ at ``reset(...)`` or ``SQL_DRIFT_ENABLE_DBA_ORACLE=1`` is exported in the
5
+ environment, the ``consult_dba`` tool becomes available; three hint
6
+ tiers per scenario are shipped here, covering all 10 scenarios (6 static
7
+ + 4 drift).
8
+
9
+ Hints escalate in specificity: tier 1 names the likely failure mode and
10
+ the diagnostic to confirm it, tier 2 gives the rewrite invariant, and
11
+ tier 3 is a near-spoiler SQL skeleton. The rubric penalizes each consult
12
+ (consultation rubric), so the agent only wins by consulting if the hint net-reduces
13
+ downstream steps.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import os
19
+ from typing import Final
20
+
21
+ # ---------------------------------------------------------------------------
22
+ # Per-scenario 3-tier hint tables
23
+ # ---------------------------------------------------------------------------
24
+
25
+
26
+ _HINTS: Final[dict[str, tuple[str, str, str]]] = {
27
+ "01_correlated_subquery": (
28
+ "[DBA tier 1] The expensive shape is a projection-time correlated subquery: "
29
+ "a COUNT over `orders` for every `users` row. Confirm by spotting "
30
+ "`SELECT COUNT(*) ... WHERE o.user_id = u.id` in the SELECT list or by "
31
+ "checking EXPLAIN for repeated dependent work. Preserve one output row per user.",
32
+ "[DBA tier 2] Aggregate fulfilled orders once by `user_id`, then LEFT JOIN that "
33
+ "small result to `users`. Keep the join outer and wrap the count with "
34
+ "`COALESCE(..., 0)` so users with no fulfilled orders stay in the result.",
35
+ "[DBA tier 3] Use `SELECT u.id, u.tier, COALESCE(c.n, 0) AS fulfilled_orders "
36
+ "FROM users u LEFT JOIN (SELECT user_id, COUNT(*) AS n FROM orders WHERE "
37
+ "status = 'fulfilled' GROUP BY user_id) c ON c.user_id = u.id ORDER BY u.id`. "
38
+ "Validate the row count equals the number of users.",
39
+ ),
40
+ "02_select_star_join": (
41
+ "[DBA tier 1] The waste is over-projection: the inner three-way join uses "
42
+ "`SELECT *`, including wide product text and order metadata, while the outer "
43
+ "query keeps only `order_id`, product `name`, and `qty`.",
44
+ "[DBA tier 2] Inline the join and project exactly `oi.order_id`, `p.name`, and "
45
+ "`oi.qty`. Keep the products and orders joins plus the filters "
46
+ "`p.category = 'books'` and `oi.qty >= 2`; the wrapper exists only to hide "
47
+ "the star projection.",
48
+ "[DBA tier 3] Rewrite as `SELECT oi.order_id, p.name, oi.qty FROM order_items oi "
49
+ "JOIN products p ON p.id = oi.product_id JOIN orders o ON o.id = oi.order_id "
50
+ "WHERE p.category = 'books' AND oi.qty >= 2 ORDER BY oi.order_id, p.name`.",
51
+ ),
52
+ "03_cartesian_join": (
53
+ "[DBA tier 1] This is an accidental cartesian join. `FROM events e, tenants t` "
54
+ "combined with `t.id = e.tenant_id + 0` prevents the optimizer from seeing a "
55
+ "clean tenant-key join early.",
56
+ "[DBA tier 2] Turn the comma join into an explicit equijoin on the tenant key. "
57
+ "Move only `t.id = e.tenant_id` into `ON`; keep the severity filter in `WHERE` "
58
+ "and preserve grouping by tenant tier.",
59
+ "[DBA tier 3] Use `SELECT t.tier, COUNT(*) AS n FROM events e JOIN tenants t "
60
+ "ON t.id = e.tenant_id WHERE e.severity IN ('error', 'critical') GROUP BY "
61
+ "t.tier ORDER BY t.tier`. Avoid arithmetic on the join key.",
62
+ ),
63
+ "04_distinct_groupby": (
64
+ "[DBA tier 1] The duplicate-removal work is redundant. `GROUP BY session_id, path` "
65
+ "already emits one row per `(session_id, path)` pair, so a leading `DISTINCT` "
66
+ "adds a second deduplication pass over grouped rows.",
67
+ "[DBA tier 2] Do not introduce a CTE or change the aggregation grain. Remove only "
68
+ "`DISTINCT`; keep `COUNT(*) AS hits`, the same GROUP BY keys, and the same "
69
+ "ordering so row identity and sort order stay stable.",
70
+ "[DBA tier 3] The target shape is `SELECT session_id, path, COUNT(*) AS hits "
71
+ "FROM pageviews GROUP BY session_id, path ORDER BY session_id, path`. Validate "
72
+ "against the baseline result before comparing runtime.",
73
+ ),
74
+ "05_nested_subquery": (
75
+ "[DBA tier 1] The nested `IN` clauses express a semi-join: authors who wrote "
76
+ "comments on published articles. The important identity is `comments.author_id`, "
77
+ "not `articles.author_id`.",
78
+ "[DBA tier 2] Flatten to `authors -> comments -> articles`, filter "
79
+ "`articles.status = 'published'`, and select distinct author display names. "
80
+ "`DISTINCT` is required here because one author can have many qualifying comments.",
81
+ "[DBA tier 3] Use `SELECT DISTINCT a.display_name FROM authors a JOIN comments c "
82
+ "ON c.author_id = a.id JOIN articles ar ON ar.id = c.article_id WHERE "
83
+ "ar.status = 'published' ORDER BY a.display_name`.",
84
+ ),
85
+ "06_having_as_where": (
86
+ "[DBA tier 1] `status = 'fulfilled'` is a row-level predicate sitting in HAVING, "
87
+ "so the engine groups every status first and discards most groups afterward. "
88
+ "Only `SUM(amount_cents) >= 100000` truly belongs after aggregation.",
89
+ "[DBA tier 2] Move the status filter into `WHERE` before the GROUP BY. Keep "
90
+ "`status` in the projection and grouping to preserve the result shape, then "
91
+ "leave the aggregate threshold in HAVING.",
92
+ "[DBA tier 3] Use `SELECT tenant_id, status, SUM(amount_cents) AS total_cents "
93
+ "FROM orders WHERE status = 'fulfilled' GROUP BY tenant_id, status HAVING "
94
+ "SUM(amount_cents) >= 100000 ORDER BY tenant_id`.",
95
+ ),
96
+ "07_drift_column_rename": (
97
+ "[DBA tier 1] If the old aggregation now fails with an unknown `user_id`, this is "
98
+ "schema drift rather than a performance issue. Read the changelog or describe "
99
+ "`orders`; `users.id` is unchanged, but the order-owner column moved.",
100
+ "[DBA tier 2] Replace every reference to `orders.user_id` with `orders.account_id` "
101
+ "in SELECT, GROUP BY, JOIN, and ORDER BY positions. Do not change the aggregate "
102
+ "logic; the rename preserves row semantics.",
103
+ "[DBA tier 3] Submit `SELECT account_id, COUNT(*) AS n_orders, ROUND(SUM(amount), 2) "
104
+ "AS total FROM orders GROUP BY account_id ORDER BY account_id`. Validate that "
105
+ "counts and totals match the pre-drift business result.",
106
+ ),
107
+ "08_drift_date_format": (
108
+ "[DBA tier 1] The `events.ts` identifier still exists, but its type changed from "
109
+ "ISO text to BIGINT epoch milliseconds. A string date predicate can parse or "
110
+ "compare incorrectly; confirm with `describe_table('events')` and samples.",
111
+ "[DBA tier 2] Keep the same half-open UTC day window, but express both bounds as "
112
+ "epoch-ms integers. For 2026-04-21T00:00:00Z through the next midnight, use "
113
+ "`1776729600000 <= ts < 1776816000000`.",
114
+ "[DBA tier 3] Use `SELECT kind, COUNT(*) AS n FROM events WHERE ts >= "
115
+ "1776729600000 AND ts < 1776816000000 GROUP BY kind ORDER BY kind`. Do not quote "
116
+ "the bounds; they must be numeric comparisons against the BIGINT column.",
117
+ ),
118
+ "09_drift_enum_rule": (
119
+ "[DBA tier 1] A formerly valid equality on `status = 'active'` now silently loses "
120
+ "rows because the business state was split into multiple stored labels. Sample "
121
+ "`tenants.status` before assuming the old lowercase value still exists.",
122
+ '[DBA tier 2] Preserve the business meaning "active tenants" by filtering on the '
123
+ "union of replacement labels. Keep the same grouping by tier and ordering; only "
124
+ "the status predicate changes.",
125
+ "[DBA tier 3] Use `SELECT tier, COUNT(*) AS n FROM tenants WHERE status IN "
126
+ "('ACTIVE', 'ACTIVE_V2') GROUP BY tier ORDER BY tier`. Avoid `LOWER(status) = "
127
+ "'active'`; it misses `ACTIVE_V2`.",
128
+ ),
129
+ "10_drift_field_deprecation": (
130
+ "[DBA tier 1] The inline `posts.author_name` column was normalized away. Describe "
131
+ "`posts` and list tables: you should see `posts.users_id` plus a new `users` "
132
+ "lookup carrying the human-readable name.",
133
+ "[DBA tier 2] Join `posts` to `users` through the new FK, group by `u.full_name`, "
134
+ "and alias it back to `author_name` so the result keeps the old report shape. "
135
+ "The post count still comes from `posts`.",
136
+ "[DBA tier 3] Use `SELECT u.full_name AS author_name, COUNT(*) AS n_posts FROM "
137
+ "posts p JOIN users u ON u.id = p.users_id GROUP BY u.full_name ORDER BY "
138
+ "u.full_name`.",
139
+ ),
140
+ }
141
+
142
+
143
+ # ---------------------------------------------------------------------------
144
+ # Public API
145
+ # ---------------------------------------------------------------------------
146
+
147
+
148
+ def is_enabled(
149
+ reset_flag: bool | None = None, *, env_var: str = "SQL_DRIFT_ENABLE_DBA_ORACLE"
150
+ ) -> bool:
151
+ """Resolve the feature flag from (reset kwarg, env var, default-off)."""
152
+ if reset_flag is not None:
153
+ return bool(reset_flag)
154
+ raw = os.environ.get(env_var, "").strip().lower()
155
+ return raw in ("1", "true", "yes", "on")
156
+
157
+
158
+ def get_hint(scenario_id: str, tier: int) -> str:
159
+ """Return the hint for ``(scenario_id, tier)``; clamps tier to [1, 3].
160
+
161
+ Raises :class:`KeyError` on unknown scenario so tests can detect when
162
+ a new scenario was added without a hint table entry.
163
+ """
164
+ if scenario_id not in _HINTS:
165
+ raise KeyError(f"no DBA hints for scenario_id={scenario_id!r}; known: {sorted(_HINTS)}")
166
+ tier = max(1, min(3, int(tier)))
167
+ return _HINTS[scenario_id][tier - 1]
168
+
169
+
170
+ def has_hints(scenario_id: str) -> bool:
171
+ return scenario_id in _HINTS
172
+
173
+
174
+ def known_scenarios() -> frozenset[str]:
175
+ return frozenset(_HINTS)
176
+
177
+
178
+ __all__ = [
179
+ "get_hint",
180
+ "has_hints",
181
+ "is_enabled",
182
+ "known_scenarios",
183
+ ]
actors/engineering_manager.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Engineering-manager narrator — authors deterministic drift runbooks.
2
+
3
+ Consumed by the ``read_changelog`` tool. Output is deterministic per
4
+ ``DriftConfig`` so tests can assert it character-for-character and the
5
+ agent can learn to parse drift kinds from the text.
6
+
7
+ Entries are concise migration notes: what changed, why the old query
8
+ breaks, how to adapt, and what to validate before submission.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import TYPE_CHECKING
14
+
15
+ if TYPE_CHECKING:
16
+ from scenarios.base import DriftConfig
17
+
18
+
19
+ def _sentence(prefix: str, body: str) -> str:
20
+ return f"[changelog] {prefix} {body}".strip()
21
+
22
+
23
+ def author_changelog(drift_config: DriftConfig) -> str:
24
+ kind = drift_config.kind
25
+ p = drift_config.payload
26
+ if kind == "column_rename":
27
+ return _sentence(
28
+ "Schema change -",
29
+ (
30
+ f"`{p['table']}.{p['old']}` was renamed to `{p['table']}.{p['new']}` "
31
+ "as part of a naming normalization. Impact: queries that SELECT, JOIN, "
32
+ "GROUP BY, ORDER BY, or filter on the old identifier now fail at bind "
33
+ "time. Migration: use the new column everywhere the old one appeared; "
34
+ "related table primary keys are unchanged. Validate that counts and "
35
+ "aggregates are unchanged after the rename."
36
+ ),
37
+ )
38
+ if kind == "date_format":
39
+ return _sentence(
40
+ "Format change -",
41
+ (
42
+ f"`{p['table']}.{p['col']}` no longer stores ISO-8601 strings; it now "
43
+ "stores BIGINT milliseconds since the Unix epoch. Impact: quoted "
44
+ "timestamp literals and text comparisons no longer express the right "
45
+ "predicate. Migration: keep the same half-open time window, but compare "
46
+ "against numeric epoch-ms bounds. Validate with describe_table and a "
47
+ "small sample before submitting."
48
+ ),
49
+ )
50
+ if kind == "enum_rule":
51
+ new_values = ", ".join(f"`{v}`" for v in p["new_values"])
52
+ return _sentence(
53
+ "Business-rule change -",
54
+ (
55
+ f"`{p['table']}.{p['col']}` value `{p['old_value']}` has been split "
56
+ f"into {new_values}. Impact: equality predicates on the old value "
57
+ "silently undercount after the deploy. Migration: replace the single "
58
+ "value predicate with an IN predicate over every replacement label. "
59
+ "Validate by sampling the enum distribution and preserving the existing "
60
+ "GROUP BY and projection."
61
+ ),
62
+ )
63
+ if kind == "field_deprecation":
64
+ orig_t, orig_c = p["orig"]
65
+ lt, lid, lname = p["lookup"]
66
+ fk_col = f"{lt}_{lid}"
67
+ return _sentence(
68
+ "Deprecation -",
69
+ (
70
+ f"`{orig_t}.{orig_c}` was deprecated and replaced by `{orig_t}.{fk_col}` "
71
+ f"pointing at `{lt}.{lid}`; the display value now lives on `{lt}.{lname}`. "
72
+ "Impact: projections or groups over the old inline string fail after "
73
+ "drift. Migration: join through the lookup table, project the display "
74
+ "column, and keep the old output alias if callers expect it. Validate "
75
+ "row counts after the join."
76
+ ),
77
+ )
78
+ raise ValueError(f"unknown drift kind={kind!r}")
79
+
80
+
81
+ __all__ = ["author_changelog"]
client.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SQLDrift ``EnvClient`` — tool-aware payload constructors + response parser.
2
+
3
+ Inherits :class:`openenv.core.env_client.EnvClient` so TRL rollouts,
4
+ notebook exploration, and integration tests all use the same WS-backed
5
+ session semantics. Stateful episodes MUST go through the ``/ws`` channel
6
+ (HTTP ``/step`` is stateless: one fresh env per request).
7
+
8
+ Convenience constructors (:meth:`SqlDriftEnv.action_list_tables`, etc.)
9
+ hide the discriminated-union boilerplate so agent code reads naturally::
10
+
11
+ env = SqlDriftEnv(base_url="http://localhost:8000").sync()
12
+ with env:
13
+ r = env.reset(seed=42, scenario_id="03_cartesian_join")
14
+ r = env.step(SqlDriftEnv.action_run_query("SELECT COUNT(*) FROM events"))
15
+ ...
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ from typing import Any
21
+
22
+ from openenv.core.client_types import StepResult
23
+ from openenv.core.env_client import EnvClient
24
+
25
+ from models import (
26
+ ConsultDBAPayload,
27
+ DescribeTablePayload,
28
+ ExplainQueryPayload,
29
+ ListTablesPayload,
30
+ ReadChangelogPayload,
31
+ RunQueryPayload,
32
+ SampleRowsPayload,
33
+ SqlDriftAction,
34
+ SqlDriftObservation,
35
+ SqlDriftState,
36
+ SubmitRewritePayload,
37
+ ToolName,
38
+ )
39
+
40
+
41
+ class SqlDriftEnv(EnvClient[SqlDriftAction, SqlDriftObservation, SqlDriftState]):
42
+ """Tool-aware client for the SQLDrift OpenEnv environment."""
43
+
44
+ # ------------------------------------------------------------------
45
+ # EnvClient ABC implementations
46
+ # ------------------------------------------------------------------
47
+
48
+ def _step_payload(self, action: SqlDriftAction) -> dict[str, Any]:
49
+ return action.model_dump(mode="json")
50
+
51
+ def _parse_result(self, payload: dict[str, Any]) -> StepResult[SqlDriftObservation]:
52
+ obs_data = payload.get("observation", {})
53
+ observation = SqlDriftObservation.model_validate(obs_data)
54
+ # Base transport strips reward + done off the observation dict — we
55
+ # re-populate them so the agent can read straight off `.observation`.
56
+ reward = payload.get("reward")
57
+ done = bool(payload.get("done", False))
58
+ observation.reward = reward
59
+ observation.done = done
60
+ return StepResult(observation=observation, reward=reward, done=done)
61
+
62
+ def _parse_state(self, payload: dict[str, Any]) -> SqlDriftState:
63
+ return SqlDriftState.model_validate(payload)
64
+
65
+ # ------------------------------------------------------------------
66
+ # Action factories — one per tool, accepting only the args that tool
67
+ # cares about; payload.kind is filled in automatically.
68
+ # ------------------------------------------------------------------
69
+
70
+ @staticmethod
71
+ def action_list_tables() -> SqlDriftAction:
72
+ return SqlDriftAction(tool=ToolName.LIST_TABLES, payload=ListTablesPayload())
73
+
74
+ @staticmethod
75
+ def action_describe_table(table: str) -> SqlDriftAction:
76
+ return SqlDriftAction(
77
+ tool=ToolName.DESCRIBE_TABLE,
78
+ payload=DescribeTablePayload(table=table),
79
+ )
80
+
81
+ @staticmethod
82
+ def action_sample_rows(table: str, limit: int = 5) -> SqlDriftAction:
83
+ return SqlDriftAction(
84
+ tool=ToolName.SAMPLE_ROWS,
85
+ payload=SampleRowsPayload(table=table, limit=limit),
86
+ )
87
+
88
+ @staticmethod
89
+ def action_run_query(sql: str) -> SqlDriftAction:
90
+ return SqlDriftAction(
91
+ tool=ToolName.RUN_QUERY,
92
+ payload=RunQueryPayload(sql=sql),
93
+ )
94
+
95
+ @staticmethod
96
+ def action_explain_query(sql: str) -> SqlDriftAction:
97
+ return SqlDriftAction(
98
+ tool=ToolName.EXPLAIN_QUERY,
99
+ payload=ExplainQueryPayload(sql=sql),
100
+ )
101
+
102
+ @staticmethod
103
+ def action_read_changelog() -> SqlDriftAction:
104
+ return SqlDriftAction(tool=ToolName.READ_CHANGELOG, payload=ReadChangelogPayload())
105
+
106
+ @staticmethod
107
+ def action_submit_rewrite(sql: str) -> SqlDriftAction:
108
+ return SqlDriftAction(
109
+ tool=ToolName.SUBMIT_REWRITE,
110
+ payload=SubmitRewritePayload(sql=sql),
111
+ )
112
+
113
+ @staticmethod
114
+ def action_consult_dba(question: str) -> SqlDriftAction:
115
+ return SqlDriftAction(
116
+ tool=ToolName.CONSULT_DBA,
117
+ payload=ConsultDBAPayload(question=question),
118
+ )
119
+
120
+
121
+ __all__ = ["SqlDriftEnv"]
engine/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Deterministic verifier / profiler / drift / reward engine."""
engine/drift.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Drift engine: four atomic, idempotent DDL operations.
2
+
3
+ Each apply_* function mutates ``conn`` in place inside a DuckDB
4
+ ``BEGIN; ... COMMIT`` pair and returns a machine-readable changelog
5
+ string. Humans consume the string via the :class:`read_changelog`
6
+ tool; the rubric consults a separate drift-acknowledgement flag on the
7
+ runtime state, not the string itself.
8
+
9
+ Idempotency is enforced via a post-condition schema probe: once the drift
10
+ has been applied (the target column / enum value is in the expected
11
+ post-state), a second call short-circuits with the same changelog string.
12
+ This matters because the environment's drift-trigger check runs every
13
+ step and needs to be safe to retry.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import TYPE_CHECKING, Any
19
+
20
+ if TYPE_CHECKING:
21
+ import duckdb
22
+
23
+
24
+ # DuckDB auto-commits DDL and forbids mixing multi-statement transactions
25
+ # with schema alterations across commit boundaries. Each drift operation
26
+ # therefore executes its statements sequentially on the default
27
+ # auto-commit connection; individual DML statements (UPDATEs) are
28
+ # internally atomic at the statement level, which is sufficient for the
29
+ # fixture mutation the env needs. If a drift operation raises mid-way we
30
+ # tear down and re-seed the DuckDB via ScenarioSpec.materialize — there's
31
+ # no long-lived on-disk state to roll back.
32
+
33
+
34
+ def _table_columns(conn: duckdb.DuckDBPyConnection, table: str) -> list[str]:
35
+ rows = conn.execute(f"PRAGMA table_info('{table}')").fetchall()
36
+ # PRAGMA table_info returns (cid, name, type, notnull, dflt_value, pk)
37
+ return [r[1] for r in rows]
38
+
39
+
40
+ def _table_exists(conn: duckdb.DuckDBPyConnection, table: str) -> bool:
41
+ rows = conn.execute(
42
+ "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = ?", [table]
43
+ ).fetchone()
44
+ return bool(rows and rows[0])
45
+
46
+
47
+ # =============================================================================
48
+ # Column rename
49
+ # =============================================================================
50
+
51
+
52
+ def apply_column_rename(conn: duckdb.DuckDBPyConnection, payload: dict[str, Any]) -> str:
53
+ """``{"table": str, "old": str, "new": str}``."""
54
+ table = payload["table"]
55
+ old = payload["old"]
56
+ new = payload["new"]
57
+
58
+ cols = _table_columns(conn, table)
59
+ if new in cols and old not in cols:
60
+ return f"rename_already_applied:{table}.{old}->{new}"
61
+ if old not in cols:
62
+ raise ValueError(f"column_rename: {table}.{old} missing (cols={cols})")
63
+
64
+ conn.execute(f'ALTER TABLE "{table}" RENAME COLUMN "{old}" TO "{new}"')
65
+ return f"rename:{table}.{old}->{new}"
66
+
67
+
68
+ # =============================================================================
69
+ # Date format change (iso_string → epoch_ms)
70
+ # =============================================================================
71
+
72
+
73
+ def apply_date_format_change(conn: duckdb.DuckDBPyConnection, payload: dict[str, Any]) -> str:
74
+ """``{"table": str, "col": str, "from": "iso_string", "to": "epoch_ms"}``.
75
+
76
+ Only the one direction is supported for now; the payload still carries
77
+ from/to for forward-compatibility and audit.
78
+ """
79
+ table = payload["table"]
80
+ col = payload["col"]
81
+ from_fmt = payload.get("from", "iso_string")
82
+ to_fmt = payload.get("to", "epoch_ms")
83
+ if (from_fmt, to_fmt) != ("iso_string", "epoch_ms"):
84
+ raise NotImplementedError(
85
+ f"date_format_change only supports iso_string→epoch_ms, got {from_fmt}→{to_fmt}"
86
+ )
87
+
88
+ cols = _table_columns(conn, table)
89
+ # Idempotent: once column is BIGINT, consider it applied.
90
+ type_row = conn.execute(
91
+ "SELECT data_type FROM information_schema.columns WHERE table_name = ? AND column_name = ?",
92
+ [table, col],
93
+ ).fetchone()
94
+ if type_row is None:
95
+ raise ValueError(f"date_format_change: {table}.{col} missing (cols={cols})")
96
+ if "BIGINT" in type_row[0].upper() or "INT" in type_row[0].upper():
97
+ return f"date_format_already_applied:{table}.{col}"
98
+
99
+ tmp = f"{col}_epoch_ms"
100
+ conn.execute(f'ALTER TABLE "{table}" ADD COLUMN "{tmp}" BIGINT')
101
+ conn.execute(
102
+ f'UPDATE "{table}" SET "{tmp}" = '
103
+ f'CAST(EXTRACT(EPOCH FROM CAST("{col}" AS TIMESTAMP)) * 1000 AS BIGINT)'
104
+ )
105
+ conn.execute(f'ALTER TABLE "{table}" DROP COLUMN "{col}"')
106
+ conn.execute(f'ALTER TABLE "{table}" RENAME COLUMN "{tmp}" TO "{col}"')
107
+
108
+ return f"date_format:{table}.{col}:iso_string->epoch_ms"
109
+
110
+
111
+ # =============================================================================
112
+ # Enum rule change (split `old_value` into N new values)
113
+ # =============================================================================
114
+
115
+
116
+ def apply_enum_rule_change(conn: duckdb.DuckDBPyConnection, payload: dict[str, Any]) -> str:
117
+ """``{"table": str, "col": str, "old_value": str, "new_values": list[str]}``.
118
+
119
+ Rows holding ``old_value`` are re-distributed deterministically into
120
+ ``new_values`` (round-robin by rowid) so the split is reproducible.
121
+ """
122
+ table = payload["table"]
123
+ col = payload["col"]
124
+ old_value = payload["old_value"]
125
+ new_values: list[str] = list(payload["new_values"])
126
+ if not new_values:
127
+ raise ValueError("enum_rule_change: new_values must be non-empty")
128
+
129
+ count_row = conn.execute(
130
+ f'SELECT COUNT(*) FROM "{table}" WHERE "{col}" = ?', [old_value]
131
+ ).fetchone()
132
+ count_old = count_row[0] if count_row is not None else 0
133
+ # Idempotent: if old_value has already been drained AND any of the
134
+ # new_values is present, treat as applied.
135
+ if count_old == 0:
136
+ has_new_row = conn.execute(
137
+ f'SELECT COUNT(*) FROM "{table}" WHERE "{col}" IN ({",".join("?" * len(new_values))})',
138
+ new_values,
139
+ ).fetchone()
140
+ has_new = has_new_row[0] if has_new_row is not None else 0
141
+ if has_new > 0:
142
+ return f"enum_rule_already_applied:{table}.{col}:{old_value}->{new_values}"
143
+
144
+ # Deterministic split by rowid mod N.
145
+ case_branches = " ".join(
146
+ f"WHEN mod(rid, {len(new_values)}) = {i} THEN '{v}'" for i, v in enumerate(new_values)
147
+ )
148
+ conn.execute(
149
+ f"CREATE TEMP TABLE _enum_remap AS "
150
+ f"SELECT rowid AS rid, "
151
+ f"CASE {case_branches} END AS new_val "
152
+ f'FROM "{table}" WHERE "{col}" = ?',
153
+ [old_value],
154
+ )
155
+ conn.execute(
156
+ f'UPDATE "{table}" SET "{col}" = _enum_remap.new_val '
157
+ f'FROM _enum_remap WHERE _enum_remap.rid = "{table}".rowid'
158
+ )
159
+ conn.execute("DROP TABLE _enum_remap")
160
+
161
+ return f"enum_rule:{table}.{col}:{old_value}->{'+'.join(new_values)}"
162
+
163
+
164
+ # =============================================================================
165
+ # Field deprecation (replace inline string col with FK lookup)
166
+ # =============================================================================
167
+
168
+
169
+ def apply_field_deprecation(conn: duckdb.DuckDBPyConnection, payload: dict[str, Any]) -> str:
170
+ """``{"orig": (table, col), "lookup": (table, id_col, name_col)}``.
171
+
172
+ - Creates the lookup table (if missing) and seeds it with distinct values
173
+ observed on ``orig.col``.
174
+ - Adds ``orig.<lookup_id>`` with a FK-style backfill.
175
+ - Drops ``orig.col``.
176
+ """
177
+ orig_table, orig_col = payload["orig"]
178
+ lookup_table, lookup_id_col, lookup_name_col = payload["lookup"]
179
+ new_fk_col = f"{lookup_table}_{lookup_id_col}" # e.g. "users_id"
180
+
181
+ orig_cols = _table_columns(conn, orig_table)
182
+ if orig_col not in orig_cols and new_fk_col in orig_cols:
183
+ return f"field_deprecation_already_applied:{orig_table}.{orig_col}"
184
+ if orig_col not in orig_cols:
185
+ raise ValueError(f"field_deprecation: {orig_table}.{orig_col} missing (cols={orig_cols})")
186
+
187
+ if not _table_exists(conn, lookup_table):
188
+ conn.execute(
189
+ f'CREATE TABLE "{lookup_table}" ('
190
+ f' "{lookup_id_col}" BIGINT PRIMARY KEY,'
191
+ f' "{lookup_name_col}" VARCHAR'
192
+ ");"
193
+ )
194
+ conn.execute(
195
+ f'INSERT INTO "{lookup_table}" ("{lookup_id_col}", "{lookup_name_col}") '
196
+ f"SELECT ROW_NUMBER() OVER (ORDER BY v) + "
197
+ f'COALESCE((SELECT MAX("{lookup_id_col}") FROM "{lookup_table}"), 0), v '
198
+ f'FROM (SELECT DISTINCT "{orig_col}" AS v FROM "{orig_table}") '
199
+ f"WHERE v IS NOT NULL "
200
+ f' AND v NOT IN (SELECT "{lookup_name_col}" FROM "{lookup_table}");'
201
+ )
202
+ conn.execute(f'ALTER TABLE "{orig_table}" ADD COLUMN "{new_fk_col}" BIGINT')
203
+ conn.execute(
204
+ f'UPDATE "{orig_table}" SET "{new_fk_col}" = lookup."{lookup_id_col}" '
205
+ f'FROM "{lookup_table}" lookup '
206
+ f'WHERE lookup."{lookup_name_col}" = "{orig_table}"."{orig_col}"'
207
+ )
208
+ conn.execute(f'ALTER TABLE "{orig_table}" DROP COLUMN "{orig_col}"')
209
+
210
+ return (
211
+ f"field_deprecation:{orig_table}.{orig_col}->"
212
+ f"{orig_table}.{new_fk_col}→{lookup_table}.{lookup_name_col}"
213
+ )
214
+
215
+
216
+ # =============================================================================
217
+ # Dispatcher
218
+ # =============================================================================
219
+
220
+
221
+ DRIFT_HANDLERS = {
222
+ "column_rename": apply_column_rename,
223
+ "date_format": apply_date_format_change,
224
+ "enum_rule": apply_enum_rule_change,
225
+ "field_deprecation": apply_field_deprecation,
226
+ }
227
+
228
+
229
+ def apply_drift(conn: duckdb.DuckDBPyConnection, kind: str, payload: dict[str, Any]) -> str:
230
+ if kind not in DRIFT_HANDLERS:
231
+ raise ValueError(f"unknown drift kind={kind!r}; known: {sorted(DRIFT_HANDLERS)}")
232
+ return DRIFT_HANDLERS[kind](conn, payload)
233
+
234
+
235
+ __all__ = [
236
+ "DRIFT_HANDLERS",
237
+ "apply_column_rename",
238
+ "apply_date_format_change",
239
+ "apply_drift",
240
+ "apply_enum_rule_change",
241
+ "apply_field_deprecation",
242
+ ]
engine/profiler.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Query profiling utilities.
2
+
3
+ A watchdog-wrapped DuckDB execute plus a median-of-3 warm timer.
4
+
5
+ * :func:`execute_once_timed` runs a statement exactly once, enforcing a
6
+ hard ``timeout_s`` wall-clock budget. It is the single entry point used
7
+ by the env for agent-provided SQL so the documented query timeout
8
+ cannot be bypassed. An optional ``max_rows`` caps result-set
9
+ materialization — the fetch is aborted as soon as more than
10
+ ``max_rows`` rows are observed, so a pathological ``SELECT *`` cannot
11
+ drive the server OOM before the caller's size check runs.
12
+ * :func:`execute_hash_timed` executes a statement once and hashes its full
13
+ result incrementally via ``fetchmany`` so correctness checks do not have
14
+ to materialize the full row set in Python memory.
15
+ * :func:`median_of_3_warm_ms` performs one untimed warm-up then three
16
+ timed runs and returns the median milliseconds. Used by scenario
17
+ materialization to publish a stable baseline runtime.
18
+
19
+ Both helpers raise :class:`TimeoutError` when a single run exceeds the
20
+ budget; ``duckdb.Error`` propagates unchanged to the caller.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import contextlib
26
+ import os
27
+ import threading
28
+ import time
29
+ from collections.abc import Callable, Iterator
30
+ from dataclasses import dataclass
31
+ from typing import TYPE_CHECKING, Any, cast
32
+
33
+ from engine.verifier import canonical_row_hash
34
+ from utilities.logger import get_module_logger
35
+
36
+ if TYPE_CHECKING:
37
+ import duckdb
38
+
39
+
40
+ DEFAULT_TIMEOUT_S: float = 2.0
41
+ INTERRUPT_GRACE_S: float = 0.25
42
+ # Maximum number of watchdog escalations (leaked threads) tolerated before
43
+ # logging at CRITICAL. Override via SQL_DRIFT_MAX_LEAKED_WATCHDOG_THREADS.
44
+ MAX_LEAKED_WATCHDOG_THREADS: int = int(os.environ.get("SQL_DRIFT_MAX_LEAKED_WATCHDOG_THREADS", "3"))
45
+
46
+ _LOG = get_module_logger(__name__)
47
+ _FETCH_CHUNK_ROWS = 1024
48
+
49
+ # Module-level counter — incremented each time a watchdog thread survives
50
+ # interrupt (i.e. a genuine escalation, not a normal timeout). Thread-safe
51
+ # via _watchdog_leak_lock. Callers can read this via get_watchdog_leak_count().
52
+ _watchdog_leak_lock: threading.Lock = threading.Lock()
53
+ _watchdog_leaked_count: int = 0
54
+
55
+
56
+ def get_watchdog_leak_count() -> int:
57
+ """Return the cumulative number of watchdog threads that survived interrupt.
58
+
59
+ A non-zero value means at least one DuckDB worker thread was not stopped
60
+ cleanly and is still alive in the background. Production monitoring should
61
+ alert when this exceeds :data:`MAX_LEAKED_WATCHDOG_THREADS`.
62
+ """
63
+ return _watchdog_leaked_count
64
+
65
+
66
+ class QueryWatchdogEscalationError(RuntimeError):
67
+ """DuckDB worker survived interrupt; the connection is no longer safe."""
68
+
69
+
70
+ @dataclass(frozen=True)
71
+ class TimedResult:
72
+ """Output of :func:`execute_once_timed`.
73
+
74
+ ``columns`` preserves DuckDB's cursor ``description`` order so callers
75
+ can emit a :class:`models.RunQueryResult` without re-executing the
76
+ query just to recover column names.
77
+
78
+ ``truncated`` is ``True`` when the caller supplied a ``max_rows`` cap
79
+ and the query produced strictly more rows than that cap; in that
80
+ case ``rows`` contains exactly ``max_rows + 1`` entries (the
81
+ one-over read that proves overflow). Callers that care about size
82
+ limits must branch on ``truncated`` rather than re-checking
83
+ ``len(rows)`` against their cap.
84
+ """
85
+
86
+ columns: list[str]
87
+ rows: list[tuple[Any, ...]]
88
+ elapsed_ms: float
89
+ truncated: bool = False
90
+
91
+
92
+ def _fetch_capped(
93
+ cursor: duckdb.DuckDBPyConnection,
94
+ max_rows: int,
95
+ ) -> tuple[list[tuple[Any, ...]], bool]:
96
+ """Drain at most ``max_rows + 1`` rows from ``cursor`` via fetchmany.
97
+
98
+ Returns ``(rows, truncated)``. When ``truncated`` is ``True`` the
99
+ cursor still has unread rows — we stopped on the first row past the
100
+ cap so the caller can signal overflow without materialising the
101
+ rest of a potentially enormous result set.
102
+ """
103
+ # chunk=1024 trades a few extra Python calls for not over-fetching
104
+ # by orders of magnitude when results are modest. The +1 in the
105
+ # final budget is what makes overflow detectable.
106
+ rows: list[tuple[Any, ...]] = []
107
+ budget = max_rows + 1
108
+ while budget > 0:
109
+ batch = cursor.fetchmany(min(_FETCH_CHUNK_ROWS, budget))
110
+ if not batch:
111
+ return rows, False
112
+ rows.extend(batch)
113
+ budget -= len(batch)
114
+ return rows, len(rows) > max_rows
115
+
116
+
117
+ def _iter_cursor_rows(
118
+ cursor: duckdb.DuckDBPyConnection,
119
+ ) -> Iterator[tuple[Any, ...]]:
120
+ while True:
121
+ batch = cursor.fetchmany(_FETCH_CHUNK_ROWS)
122
+ if not batch:
123
+ return
124
+ yield from batch
125
+
126
+
127
+ def _run_worker_with_watchdog[T](
128
+ conn: duckdb.DuckDBPyConnection,
129
+ sql: str,
130
+ timeout_s: float,
131
+ worker: Callable[[], T],
132
+ ) -> T:
133
+ result_holder: dict[str, object] = {}
134
+
135
+ def runner() -> None:
136
+ try:
137
+ result_holder["result"] = worker()
138
+ except BaseException as exc: # Must forward all failures from the worker thread.
139
+ result_holder["error"] = exc
140
+
141
+ thread = threading.Thread(target=runner, daemon=True)
142
+ thread.start()
143
+ thread.join(timeout_s)
144
+ if thread.is_alive():
145
+ # DuckDB's interrupt API is connection-scoped and thread-safe;
146
+ # we ask the query to unwind and then wait *unconditionally*
147
+ # for the worker to exit before surfacing the timeout to the
148
+ # caller. If we returned while the thread were still alive, it
149
+ # would retain access to ``conn`` and its result could race
150
+ # future queries on the same connection — a previously
151
+ # observed source of flaky post-timeout behaviour. In practice
152
+ # DuckDB's interrupt releases the worker within a handful of
153
+ # milliseconds; if the engine ever fails to honour interrupt
154
+ # the process will hang here, which is the correct failure
155
+ # mode for a connection whose state is no longer safe to
156
+ # reuse.
157
+ with contextlib.suppress(Exception):
158
+ conn.interrupt()
159
+ thread.join(INTERRUPT_GRACE_S)
160
+ if thread.is_alive():
161
+ global _watchdog_leaked_count
162
+ with _watchdog_leak_lock:
163
+ _watchdog_leaked_count += 1
164
+ leak_count = _watchdog_leaked_count
165
+ log_fn = _LOG.critical if leak_count > MAX_LEAKED_WATCHDOG_THREADS else _LOG.error
166
+ log_fn(
167
+ "query watchdog failed to stop worker after %.3fs timeout + %.3fs grace"
168
+ " (cumulative leaked threads: %d)",
169
+ timeout_s,
170
+ INTERRUPT_GRACE_S,
171
+ leak_count,
172
+ )
173
+ raise QueryWatchdogEscalationError(
174
+ f"query exceeded {timeout_s}s and worker did not stop after interrupt: {sql[:120]!r}"
175
+ )
176
+ raise TimeoutError(f"query exceeded {timeout_s}s: {sql[:120]!r}")
177
+ if "error" in result_holder:
178
+ error = result_holder["error"]
179
+ assert isinstance(error, BaseException)
180
+ raise error
181
+ return cast(T, result_holder["result"])
182
+
183
+
184
+ def _run_with_watchdog(
185
+ conn: duckdb.DuckDBPyConnection,
186
+ sql: str,
187
+ timeout_s: float,
188
+ max_rows: int | None,
189
+ ) -> TimedResult:
190
+ def worker() -> TimedResult:
191
+ start = time.perf_counter_ns()
192
+ cursor = conn.execute(sql)
193
+ columns = [d[0] for d in cursor.description] if cursor.description else []
194
+ if max_rows is None:
195
+ rows = cursor.fetchall()
196
+ truncated = False
197
+ else:
198
+ rows, truncated = _fetch_capped(cursor, max_rows)
199
+ elapsed_ns = time.perf_counter_ns() - start
200
+ return TimedResult(
201
+ columns=columns,
202
+ rows=rows,
203
+ elapsed_ms=elapsed_ns / 1_000_000.0,
204
+ truncated=truncated,
205
+ )
206
+
207
+ result = _run_worker_with_watchdog(conn, sql, timeout_s, worker)
208
+ assert isinstance(result, TimedResult)
209
+ return result
210
+
211
+
212
+ def execute_once_timed(
213
+ conn: duckdb.DuckDBPyConnection,
214
+ sql: str,
215
+ *,
216
+ timeout_s: float = DEFAULT_TIMEOUT_S,
217
+ max_rows: int | None = None,
218
+ ) -> tuple[list[tuple[Any, ...]], float]:
219
+ """Single timed execution — returns ``(rows, elapsed_ms)``.
220
+
221
+ Thin wrapper for callers that don't need column metadata or the
222
+ truncation flag.
223
+ """
224
+ res = _run_with_watchdog(conn, sql, timeout_s, max_rows)
225
+ return res.rows, res.elapsed_ms
226
+
227
+
228
+ def execute_once_with_columns(
229
+ conn: duckdb.DuckDBPyConnection,
230
+ sql: str,
231
+ *,
232
+ timeout_s: float = DEFAULT_TIMEOUT_S,
233
+ max_rows: int | None = None,
234
+ ) -> TimedResult:
235
+ """Single timed execution — returns columns + rows + elapsed_ms.
236
+
237
+ When ``max_rows`` is supplied, the fetch aborts at the first row
238
+ past the cap and ``TimedResult.truncated`` is set. The elapsed
239
+ milliseconds in that case reflect the partial scan, not the query's
240
+ would-be completion time — a truncated read is a *hard error* in
241
+ agent-facing code paths, not a performance measurement.
242
+ """
243
+ return _run_with_watchdog(conn, sql, timeout_s, max_rows)
244
+
245
+
246
+ def execute_hash_timed(
247
+ conn: duckdb.DuckDBPyConnection,
248
+ sql: str,
249
+ *,
250
+ timeout_s: float = DEFAULT_TIMEOUT_S,
251
+ ) -> tuple[str, float]:
252
+ """Single timed execution — returns ``(result_hash, elapsed_ms)``.
253
+
254
+ Unlike :func:`execute_once_timed`, this drains the cursor via
255
+ ``fetchmany`` and hashes rows incrementally, so callers can compare a
256
+ large final result to ground truth without materializing the full row
257
+ set in Python memory.
258
+ """
259
+
260
+ def worker() -> tuple[str, float]:
261
+ start = time.perf_counter_ns()
262
+ cursor = conn.execute(sql)
263
+ result_hash = canonical_row_hash(_iter_cursor_rows(cursor))
264
+ elapsed_ns = time.perf_counter_ns() - start
265
+ return result_hash, elapsed_ns / 1_000_000.0
266
+
267
+ result = _run_worker_with_watchdog(conn, sql, timeout_s, worker)
268
+ result_hash, elapsed_ms = result
269
+ assert isinstance(result_hash, str)
270
+ assert isinstance(elapsed_ms, float)
271
+ return result_hash, elapsed_ms
272
+
273
+
274
+ def median_of_3_warm_ms(
275
+ conn: duckdb.DuckDBPyConnection,
276
+ sql: str,
277
+ *,
278
+ timeout_s: float = DEFAULT_TIMEOUT_S,
279
+ ) -> float:
280
+ """Warm cache, then median-of-3 timed runs. Returns milliseconds."""
281
+ _run_with_watchdog(conn, sql, timeout_s, None)
282
+ timings = [_run_with_watchdog(conn, sql, timeout_s, None).elapsed_ms for _ in range(3)]
283
+ timings.sort()
284
+ return timings[1]
285
+
286
+
287
+ __all__ = [
288
+ "DEFAULT_TIMEOUT_S",
289
+ "INTERRUPT_GRACE_S",
290
+ "MAX_LEAKED_WATCHDOG_THREADS",
291
+ "QueryWatchdogEscalationError",
292
+ "TimedResult",
293
+ "execute_hash_timed",
294
+ "execute_once_timed",
295
+ "execute_once_with_columns",
296
+ "get_watchdog_leak_count",
297
+ "median_of_3_warm_ms",
298
+ ]
engine/reward.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SQLDrift composed rubric (correctness, drift, speedup, gates, DBA tax).
2
+
3
+ Six child rubrics, one per reward component (:data:`REWARD_COMPONENT_KEYS`):
4
+
5
+ r_correct correctness vs ground-truth hash, gated on ≥ 1.2× speedup
6
+ r_drift bonus/penalty for (not) adapting to post-drift identifiers
7
+ r_speedup tanh-shaped speedup bonus, gated on r_correct > 0
8
+ r_step_tax base step tax plus bounded productive-action rebates
9
+ r_gatekeepers escalating tool-error / repeat-failing / no-op penalties
10
+ r_consult_dba DBA-oracle consult penalties (feature-flagged; 0 when off)
11
+
12
+ All child rubrics share a single ``ctx_provider`` that returns the private
13
+ :class:`engine.runtime.RuntimeEpisodeState`; this keeps the rubric
14
+ stateless relative to the environment and makes each component
15
+ individually unit-testable with a synthesized triple
16
+ ``(RuntimeEpisodeState, SqlDriftAction, SqlDriftObservation)``.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import math
22
+ import re
23
+ from collections.abc import Callable
24
+ from typing import TYPE_CHECKING
25
+
26
+ from openenv.core.rubrics import Rubric
27
+
28
+ from models import (
29
+ SqlDriftAction,
30
+ SqlDriftObservation,
31
+ SubmitRewriteResult,
32
+ ToolError,
33
+ ToolName,
34
+ )
35
+
36
+ if TYPE_CHECKING:
37
+ from engine.runtime import RuntimeEpisodeState
38
+
39
+
40
+ # Tunable thresholds pulled out to module level so tests and future
41
+ # curriculum code share a single source of truth.
42
+ SPEEDUP_MIN: float = 1.2
43
+ SPEEDUP_CAP_FOR_INFTY: float = 64.0
44
+ STEP_TAX: float = -0.03
45
+ STEP_REBATE_LIST_TABLES: float = 0.04
46
+ STEP_REBATE_DESCRIBE_TABLE: float = 0.06
47
+ STEP_REBATE_SAMPLE_ROWS: float = 0.05
48
+ STEP_REBATE_RUN_QUERY: float = 0.04
49
+ STEP_REBATE_EXPLAIN_QUERY: float = 0.04
50
+ STEP_REBATE_READ_CHANGELOG: float = 0.08
51
+
52
+ GATE_MALFORMED_TOOL_CALL: float = -0.3
53
+ GATE_CONSECUTIVE_TOOL_ERROR: float = -0.1
54
+ GATE_REPEAT_FAILING_QUERY: float = -0.1
55
+ GATE_BASELINE_VERBATIM: float = -0.2
56
+ _MAX_ESCALATION_STEPS: int = 3
57
+
58
+ CONSULT_ESCALATION: tuple[float, float, float] = (-0.1, -0.3, -0.8)
59
+
60
+
61
+ # =============================================================================
62
+ # Helpers
63
+ # =============================================================================
64
+
65
+
66
+ def canonicalize_sql(sql: str) -> str:
67
+ """Whitespace/case/alias-insensitive canonical form.
68
+
69
+ Uses sqlglot's duckdb dialect round-trip so reorders/reformats agree;
70
+ falls back to a simple whitespace fold if sqlglot rejects the SQL
71
+ (e.g. during the baseline-verbatim check on an agent-submitted blob).
72
+ """
73
+ try:
74
+ import sqlglot
75
+
76
+ expr = sqlglot.parse_one(sql, dialect="duckdb")
77
+ return expr.sql(dialect="duckdb", comments=False, normalize=True).strip().lower()
78
+ except Exception:
79
+ return " ".join(sql.lower().split())
80
+
81
+
82
+ _IDENT_RE = re.compile(r"\b([A-Za-z_][A-Za-z0-9_]*)\b")
83
+
84
+
85
+ def _extract_identifiers(sql: str) -> frozenset[str]:
86
+ """Loose case-preserving identifier harvest.
87
+
88
+ Strips string literals before tokenizing so e.g. `WHERE x = 'ACTIVE'`
89
+ does not leak 'ACTIVE' into the identifier set.
90
+ """
91
+ stripped = re.sub(r"'[^']*'", "", sql)
92
+ return frozenset(m.group(0) for m in _IDENT_RE.finditer(stripped))
93
+
94
+
95
+ def _extract_column_references(sql: str) -> frozenset[str]:
96
+ """Column-reference identifiers via sqlglot AST (excludes alias labels).
97
+
98
+ For drift-adapt scoring, ``SELECT account_id AS user_id`` references
99
+ the new column and merely labels the output — the alias must not
100
+ count as a surviving pre-drift marker. Falls back to the regex
101
+ extractor on parse failure so malformed SQL still scores something.
102
+ """
103
+ try:
104
+ import sqlglot
105
+ expr = sqlglot.parse_one(sql, dialect="duckdb")
106
+ except Exception:
107
+ return _extract_identifiers(sql)
108
+ if expr is None:
109
+ return _extract_identifiers(sql)
110
+ return frozenset(
111
+ n.name for n in expr.walk() if isinstance(n, sqlglot.exp.Column) and n.name
112
+ )
113
+
114
+
115
+ def _literals(sql: str) -> frozenset[str]:
116
+ """All `'..'`-quoted string literals in `sql`."""
117
+ return frozenset(re.findall(r"'([^']*)'", sql))
118
+
119
+
120
+ _AGENT_MS_EPSILON: float = 1e-6
121
+
122
+
123
+ def effective_speedup(rt: RuntimeEpisodeState) -> float | None:
124
+ """Compute ``effective_speedup`` from the runtime snapshot (speedup rubric).
125
+
126
+ Single source of truth for the speedup number used across the code
127
+ base — rubric scoring, the skill library's ``avg_speedup`` field,
128
+ and the training evaluator all route through here so divergent
129
+ definitions cannot drift apart.
130
+
131
+ Returns:
132
+
133
+ * ``None`` — no submission has happened yet. Callers that need a
134
+ numeric default (e.g. the rubric, which is only invoked
135
+ post-submission) should verify ``rt.submitted`` first.
136
+ * ``+∞`` — drift has fired and the pre-drift baseline SQL no longer
137
+ executes against the post-drift schema; any correct submission is
138
+ definitionally "infinitely faster" than an unrunnable baseline.
139
+ * ``baseline_ms / max(agent_ms, ε)`` otherwise. A tiny ``ε`` clamp
140
+ guards against zero/negative timings from sub-microsecond queries
141
+ and treats them as "as fast as possible" (very large, finite
142
+ speedup) rather than silently collapsing the reward.
143
+ """
144
+ if rt.submitted_runtime_ms is None:
145
+ return None
146
+ if rt.drift_fired and rt.baseline_postdrift_raises:
147
+ return math.inf
148
+ agent_ms = max(rt.submitted_runtime_ms, _AGENT_MS_EPSILON)
149
+ return rt.baseline_runtime_ms / agent_ms
150
+
151
+
152
+ def _speedup_for_reward(rt: RuntimeEpisodeState) -> float:
153
+ """Rubric-facing speedup that never returns ``None``.
154
+
155
+ The rubric is only invoked once ``rt.submitted`` is True, so
156
+ :func:`effective_speedup` cannot return ``None`` from these call
157
+ sites; we assert that and coerce to ``0.0`` defensively if it ever
158
+ does (prevents a silent ``TypeError`` inside the reward math).
159
+ """
160
+ val = effective_speedup(rt)
161
+ return 0.0 if val is None else val
162
+
163
+
164
+ def _is_terminal_submission(
165
+ action: SqlDriftAction,
166
+ observation: SqlDriftObservation,
167
+ rt: RuntimeEpisodeState,
168
+ ) -> bool:
169
+ """True iff this step is the submission step.
170
+
171
+ The env sets ``done=True`` on a successful submission and attaches a
172
+ :class:`SubmitRewriteResult`; we gate terminal rewards on both
173
+ signals so repeated rubric calls on an unchanged state don't
174
+ double-score.
175
+ """
176
+ if not rt.submitted:
177
+ return False
178
+ if action.tool != ToolName.SUBMIT_REWRITE:
179
+ return False
180
+ tr = observation.tool_result
181
+ return isinstance(tr, SubmitRewriteResult)
182
+
183
+
184
+ def _gt_hash(rt: RuntimeEpisodeState) -> str | None:
185
+ if rt.drift_fired and rt.gt_result_hash_postdrift is not None:
186
+ return rt.gt_result_hash_postdrift
187
+ return rt.gt_result_hash_predrift
188
+
189
+
190
+ # =============================================================================
191
+ # Child rubrics
192
+ # =============================================================================
193
+
194
+
195
+ class _CtxChild(Rubric):
196
+ """Base child rubric sharing the ctx provider."""
197
+
198
+ def __init__(self, ctx_provider: Callable[[], RuntimeEpisodeState]) -> None:
199
+ super().__init__()
200
+ object.__setattr__(self, "_ctx", ctx_provider)
201
+
202
+ def forward(
203
+ self,
204
+ action: SqlDriftAction,
205
+ observation: SqlDriftObservation,
206
+ ) -> float:
207
+ raise NotImplementedError
208
+
209
+
210
+ class Correctness(_CtxChild):
211
+ """Terminal-only correctness: +1.0 / +0.5 / -1.0 by hash and speedup."""
212
+
213
+ def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
214
+ rt = self._ctx()
215
+ if not _is_terminal_submission(action, observation, rt):
216
+ return 0.0
217
+ gt = _gt_hash(rt)
218
+ agent_hash = rt.submitted_result_hash
219
+ if gt is None or agent_hash is None:
220
+ return 0.0
221
+ if agent_hash != gt:
222
+ return -1.0
223
+ speedup = _speedup_for_reward(rt)
224
+ if speedup >= SPEEDUP_MIN:
225
+ return 1.0
226
+ return 0.5
227
+
228
+
229
+ class DriftAdapt(_CtxChild):
230
+ """+0.5 for a correctly-adapted submission, -0.5 for a pre-drift-only
231
+ submission after drift fired.
232
+
233
+ Adaptation is detected against two scenario-declared identifier sets:
234
+
235
+ * ``postdrift_identifiers`` — identifiers/literals that only a
236
+ correct post-drift rewrite will introduce (e.g. ``account_id``
237
+ after a column rename, ``'ACTIVE'`` after an enum split).
238
+ * ``predrift_identifiers`` — identifiers/literals a submission that
239
+ ignored the drift would retain (e.g. ``user_id``, ``'active'``,
240
+ the ISO anchor strings under date-format drift).
241
+
242
+ A submission is considered "adapted" when it either surfaces a
243
+ post-drift marker *or* the scenario declares no distinctive
244
+ post-drift identifiers (e.g. date-format drift keeps the same
245
+ column name and only the literal shape changes) AND it does not
246
+ retain any pre-drift marker. The penalty fires only when the
247
+ submission still carries pre-drift markers AND produced the wrong
248
+ post-drift result — so a merely partial rewrite (neither pre-
249
+ nor post-flavoured) never earns a penalty it can't diagnose.
250
+ """
251
+
252
+ def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
253
+ rt = self._ctx()
254
+ # Only drift scenarios participate.
255
+ if rt.gt_result_hash_postdrift is None and not rt.drift_fired:
256
+ return 0.0
257
+ if not _is_terminal_submission(action, observation, rt):
258
+ return 0.0
259
+ inst = getattr(rt, "instance", None)
260
+ post_ids: frozenset[str] = (
261
+ getattr(inst, "postdrift_identifiers", frozenset()) or frozenset()
262
+ )
263
+ pre_ids: frozenset[str] = getattr(inst, "predrift_identifiers", frozenset()) or frozenset()
264
+ agent_sql = rt.submitted_sql or ""
265
+
266
+ idents = _extract_column_references(agent_sql)
267
+ literals = _literals(agent_sql)
268
+ markers = idents | literals
269
+ uses_post = bool(post_ids & markers)
270
+ uses_pre = bool(pre_ids & markers)
271
+
272
+ # Treat "no distinctive post identifier" scenarios as
273
+ # satisfied by absence-of-pre (see class docstring).
274
+ adapted = (uses_post or not post_ids) and not uses_pre
275
+
276
+ agent_hash = rt.submitted_result_hash
277
+ gt_post = rt.gt_result_hash_postdrift
278
+
279
+ if rt.drift_fired and agent_hash == gt_post and adapted:
280
+ return 0.5
281
+ if rt.drift_fired and uses_pre and agent_hash != gt_post:
282
+ return -0.5
283
+ return 0.0
284
+
285
+
286
+ class Speedup(_CtxChild):
287
+ """Terminal-only, gated on r_correct > 0: 0.3·tanh(log2(speedup)/3)."""
288
+
289
+ def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
290
+ rt = self._ctx()
291
+ if not _is_terminal_submission(action, observation, rt):
292
+ return 0.0
293
+ gt = _gt_hash(rt)
294
+ if gt is None or rt.submitted_result_hash != gt:
295
+ return 0.0
296
+ raw = _speedup_for_reward(rt)
297
+ if math.isinf(raw):
298
+ raw = SPEEDUP_CAP_FOR_INFTY
299
+ if raw <= 1.0:
300
+ return 0.0
301
+ return 0.3 * math.tanh(math.log2(raw) / 3.0)
302
+
303
+
304
+ class StepTax(_CtxChild):
305
+ """Base step tax plus bounded rebates for productive exploration."""
306
+
307
+ def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
308
+ rt = self._ctx()
309
+ if _is_terminal_submission(action, observation, rt):
310
+ return 0.0
311
+ rebate = max(0.0, float(getattr(rt, "last_step_productive_rebate", 0.0)))
312
+ return STEP_TAX + rebate
313
+
314
+
315
+ class Gatekeepers(_CtxChild):
316
+ """Sum of three independent penalties; repeats escalate up to a cap."""
317
+
318
+ def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
319
+ rt = self._ctx()
320
+ penalty = 0.0
321
+ # 1. Malformed / failed tool call — ToolError emitted this step.
322
+ if isinstance(observation.tool_result, ToolError):
323
+ penalty += GATE_MALFORMED_TOOL_CALL
324
+ streak = max(0, int(getattr(rt, "consecutive_tool_errors", 0)) - 1)
325
+ penalty += GATE_CONSECUTIVE_TOOL_ERROR * min(streak, _MAX_ESCALATION_STEPS)
326
+ # 2. Repeat failing query — env marks the flag on the runtime
327
+ # state immediately before invoking the rubric.
328
+ repeats = max(0, int(getattr(rt, "last_step_repeat_failing_query_count", 0)) - 1)
329
+ if repeats > 0:
330
+ penalty += GATE_REPEAT_FAILING_QUERY * min(repeats, _MAX_ESCALATION_STEPS)
331
+ # 3. Baseline-verbatim submission (Rev-3 gate — stacks with
332
+ # correctness's +0.5 partial to cap the no-op rewrite at +0.3).
333
+ if (
334
+ action.tool == ToolName.SUBMIT_REWRITE
335
+ and _is_terminal_submission(action, observation, rt)
336
+ and rt.submitted_sql_canonical == rt.baseline_sql_canonical
337
+ ):
338
+ penalty += GATE_BASELINE_VERBATIM
339
+ return penalty
340
+
341
+
342
+ class ConsultDBA(_CtxChild):
343
+ """Escalating penalties -0.1 / -0.3 / -0.8 per consult when the flag is on."""
344
+
345
+ def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
346
+ rt = self._ctx()
347
+ oracle_enabled = getattr(rt, "dba_oracle_enabled", False)
348
+ if not oracle_enabled:
349
+ return 0.0
350
+ if action.tool != ToolName.CONSULT_DBA:
351
+ return 0.0
352
+ # Count the consult THIS step by indexing into the escalation
353
+ # table using the pre-increment value (env increments on the same step).
354
+ tier = min(rt.consultations_used, len(CONSULT_ESCALATION))
355
+ if tier <= 0:
356
+ return CONSULT_ESCALATION[0]
357
+ return CONSULT_ESCALATION[tier - 1]
358
+
359
+
360
+ # =============================================================================
361
+ # Composite
362
+ # =============================================================================
363
+
364
+
365
+ class SqlDriftRubric(Rubric):
366
+ """Composite rubric: sum of six children.
367
+
368
+ Registration as attributes auto-enrolls them in
369
+ :meth:`Rubric.named_rubrics` so training loops can introspect
370
+ per-component scores.
371
+ """
372
+
373
+ def __init__(self, ctx_provider: Callable[[], RuntimeEpisodeState]) -> None:
374
+ super().__init__()
375
+ # NOTE: order matters — correctness must populate last_score before
376
+ # speedup reads it via the shared ctx_provider (both are pure
377
+ # functions of the runtime state, so identical output — but the
378
+ # explicit ordering documents the intent).
379
+ self.correctness = Correctness(ctx_provider)
380
+ self.drift_adapt = DriftAdapt(ctx_provider)
381
+ self.speedup = Speedup(ctx_provider)
382
+ self.step_tax = StepTax(ctx_provider)
383
+ self.gatekeepers = Gatekeepers(ctx_provider)
384
+ self.consult_dba = ConsultDBA(ctx_provider)
385
+
386
+ def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
387
+ total = (
388
+ self.correctness(action, observation)
389
+ + self.drift_adapt(action, observation)
390
+ + self.speedup(action, observation)
391
+ + self.step_tax(action, observation)
392
+ + self.gatekeepers(action, observation)
393
+ + self.consult_dba(action, observation)
394
+ )
395
+ return total
396
+
397
+ def component_scores(self) -> dict[str, float]:
398
+ """Return the most-recent per-component scores, keyed for W&B.
399
+
400
+ Keys match :data:`models.REWARD_COMPONENT_KEYS` so the observation
401
+ envelope and the demo plots agree on a stable schema.
402
+ """
403
+ return {
404
+ "r_correct": float(self.correctness.last_score or 0.0),
405
+ "r_drift": float(self.drift_adapt.last_score or 0.0),
406
+ "r_speedup": float(self.speedup.last_score or 0.0),
407
+ "r_step_tax": float(self.step_tax.last_score or 0.0),
408
+ "r_gatekeepers": float(self.gatekeepers.last_score or 0.0),
409
+ "r_consult_dba": float(self.consult_dba.last_score or 0.0),
410
+ }
411
+
412
+
413
+ __all__ = [
414
+ "CONSULT_ESCALATION",
415
+ "ConsultDBA",
416
+ "Correctness",
417
+ "DriftAdapt",
418
+ "GATE_BASELINE_VERBATIM",
419
+ "GATE_CONSECUTIVE_TOOL_ERROR",
420
+ "GATE_MALFORMED_TOOL_CALL",
421
+ "GATE_REPEAT_FAILING_QUERY",
422
+ "Gatekeepers",
423
+ "SPEEDUP_CAP_FOR_INFTY",
424
+ "SPEEDUP_MIN",
425
+ "STEP_REBATE_DESCRIBE_TABLE",
426
+ "STEP_REBATE_EXPLAIN_QUERY",
427
+ "STEP_REBATE_LIST_TABLES",
428
+ "STEP_REBATE_READ_CHANGELOG",
429
+ "STEP_REBATE_RUN_QUERY",
430
+ "STEP_REBATE_SAMPLE_ROWS",
431
+ "STEP_TAX",
432
+ "Speedup",
433
+ "SqlDriftRubric",
434
+ "StepTax",
435
+ "canonicalize_sql",
436
+ "effective_speedup",
437
+ ]
engine/runtime.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Private per-episode runtime state (not exposed over the wire).
2
+
3
+ This module is imported by both :mod:`engine.reward` and
4
+ :mod:`server.sql_drift_env_environment` — keeping it out of ``server/``
5
+ avoids the import cycle ``engine.reward → server → engine.reward``.
6
+
7
+ NEVER serialize or expose this over any endpoint. The public state
8
+ projection lives in :class:`models.SqlDriftState`.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from dataclasses import dataclass, field
14
+ from typing import TYPE_CHECKING, Any
15
+
16
+ from models import EpisodePhase
17
+
18
+ if TYPE_CHECKING:
19
+ import duckdb
20
+
21
+
22
+ @dataclass
23
+ class RuntimeEpisodeState:
24
+ """Private per-episode state — holds secrets (DB handle, ground truth)."""
25
+
26
+ episode_id: str
27
+ seed: int
28
+ scenario_id: str
29
+
30
+ instance: Any # scenarios.base.ScenarioInstance — forward ref avoids import cycle
31
+ conn: duckdb.DuckDBPyConnection
32
+
33
+ # Ground truth hashes (never exposed)
34
+ gt_result_hash_predrift: str
35
+ gt_result_hash_postdrift: str | None
36
+
37
+ # Baseline measurements (used by speedup + baseline-verbatim gates)
38
+ baseline_runtime_ms: float
39
+ baseline_tokens: int
40
+ baseline_sql_canonical: str
41
+ baseline_postdrift_raises: bool
42
+
43
+ # Drift timing (scheduled step + cooldown relative to first run_query).
44
+ drift_scheduled_step: int | None
45
+ connection_poisoned: bool = False
46
+ drift_fired_step: int | None = None
47
+ first_run_query_step: int | None = None
48
+
49
+ # Episode progression
50
+ step_count: int = 0
51
+ phase: EpisodePhase = EpisodePhase.DIAGNOSE
52
+ budget_steps: int = 25
53
+
54
+ # Per-step bookkeeping
55
+ failed_query_hashes: set[str] = field(default_factory=set)
56
+ failed_query_counts: dict[str, int] = field(default_factory=dict)
57
+ changelog_entries: list[str] = field(default_factory=list)
58
+ consultations_used: int = 0
59
+ listed_tables_rewarded: bool = False
60
+ described_tables_rewarded: set[str] = field(default_factory=set)
61
+ sampled_tables_rewarded: set[str] = field(default_factory=set)
62
+ run_query_rewarded: bool = False
63
+ explain_query_rewarded: bool = False
64
+ changelog_rewarded_after_drift: bool = False
65
+
66
+ # Phase-machine bookkeeping — counts successful diagnostic tool calls
67
+ # (list_tables, describe_table, sample_rows, run_query, explain_query,
68
+ # read_changelog). The DIAGNOSE → REWRITE transition fires the first
69
+ # time this becomes non-zero; SUBMIT_REWRITE is rejected while this is
70
+ # still zero (ToolErrorCode.SUBMIT_BEFORE_DIAGNOSE).
71
+ diagnostic_actions_taken: int = 0
72
+
73
+ # Submission state — populated once SUBMIT_REWRITE is accepted
74
+ submitted: bool = False
75
+ submitted_sql: str | None = None
76
+ submitted_sql_canonical: str | None = None
77
+ submitted_result_hash: str | None = None
78
+ submitted_runtime_ms: float | None = None
79
+
80
+ # Last-step signal — consumed by the rubric to compute per-step penalties
81
+ last_step_was_tool_error: bool = False
82
+ last_step_was_repeat_failing_query: bool = False
83
+ last_step_repeat_failing_query_count: int = 0
84
+ last_step_productive_rebate: float = 0.0
85
+ consecutive_tool_errors: int = 0
86
+
87
+ # Drift acknowledgement — set True the first time the agent reads the
88
+ # changelog or observes post-drift schema identifiers in a query.
89
+ drift_acknowledged: bool = False
90
+
91
+ # DBA Oracle feature flag (read by the ConsultDBA child rubric).
92
+ # Always False unless explicitly enabled at reset (kwarg or env var).
93
+ dba_oracle_enabled: bool = False
94
+
95
+ @property
96
+ def drift_fired(self) -> bool:
97
+ return self.drift_fired_step is not None
98
+
99
+ @property
100
+ def budget_steps_remaining(self) -> int:
101
+ return max(0, self.budget_steps - self.step_count)
engine/verifier.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Deterministic result verification.
2
+
3
+ Canonicalizes floats to `FLOAT_PRECISION` decimal places and treats NULL
4
+ uniformly so that two result sets with the same semantic content hash to
5
+ the same digest regardless of row order, floating-point noise, or None vs
6
+ SQL NULL representation.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import hashlib
12
+ from collections.abc import Iterable
13
+ from typing import Any
14
+
15
+ NULL_SENTINEL = "\x00NULL\x00"
16
+ FLOAT_PRECISION = 6
17
+ _DIGEST_MODULUS = 1 << 256
18
+
19
+
20
+ def _normalize_value(v: Any) -> Any:
21
+ if v is None:
22
+ return NULL_SENTINEL
23
+ if isinstance(v, float):
24
+ # NaN hashes to itself here — a NaN in rows yields a deterministic
25
+ # digest, but two NaNs round-trip to different representations when
26
+ # we `repr`. Guard explicitly.
27
+ if v != v: # NaN
28
+ return NULL_SENTINEL
29
+ return round(v, FLOAT_PRECISION)
30
+ return v
31
+
32
+
33
+ def _row_digest_int(row: Iterable[Any]) -> int:
34
+ normalized = tuple(_normalize_value(v) for v in row)
35
+ digest = hashlib.sha256(repr(normalized).encode()).digest()
36
+ return int.from_bytes(digest, "big", signed=False)
37
+
38
+
39
+ def canonical_row_hash(rows: Iterable[Iterable[Any]]) -> str:
40
+ """Order-independent hash of a result set.
41
+
42
+ This stays order-independent and duplicate-sensitive without
43
+ materializing the full result in memory. Each normalized row is
44
+ hashed once, then folded into three commutative accumulators so the
45
+ final digest is stable across row order and Python processes.
46
+ """
47
+ row_count = 0
48
+ sum_acc = 0
49
+ sumsq_acc = 0
50
+ xor_acc = 0
51
+ for row in rows:
52
+ row_count += 1
53
+ row_digest = _row_digest_int(row)
54
+ sum_acc = (sum_acc + row_digest) % _DIGEST_MODULUS
55
+ sumsq_acc = (sumsq_acc + ((row_digest * row_digest) % _DIGEST_MODULUS)) % _DIGEST_MODULUS
56
+ xor_acc ^= row_digest
57
+ payload = b"".join(
58
+ (
59
+ row_count.to_bytes(32, "big", signed=False),
60
+ sum_acc.to_bytes(32, "big", signed=False),
61
+ sumsq_acc.to_bytes(32, "big", signed=False),
62
+ xor_acc.to_bytes(32, "big", signed=False),
63
+ )
64
+ )
65
+ return hashlib.sha256(payload).hexdigest()
66
+
67
+
68
+ def result_matches(agent_rows: Iterable[Iterable[Any]], gt_hash: str) -> bool:
69
+ """True if `agent_rows` canonicalizes to the ground-truth hash."""
70
+ return canonical_row_hash(agent_rows) == gt_hash
71
+
72
+
73
+ def row_set_jaccard(a: Iterable[Iterable[Any]], b: Iterable[Iterable[Any]]) -> float:
74
+ """Jaccard over normalized row sets (order- and duplicate-insensitive).
75
+
76
+ Each input row is normalised with :func:`_normalize_value` and
77
+ collapsed into a :class:`frozenset`-style Python ``set``, so rows
78
+ that repeat within a single result are counted once. This is
79
+ deliberately *not* a multiset Jaccard — multiset semantics would
80
+ punish correct queries that legitimately emit duplicates more
81
+ harshly than intended.
82
+
83
+ Not used by the lean reward today, but kept covered by tests so
84
+ we can opt in later without rework.
85
+ """
86
+ norm_a = {tuple(_normalize_value(v) for v in row) for row in a}
87
+ norm_b = {tuple(_normalize_value(v) for v in row) for row in b}
88
+ if not norm_a and not norm_b:
89
+ return 1.0
90
+ union = norm_a | norm_b
91
+ inter = norm_a & norm_b
92
+ return len(inter) / len(union)
93
+
94
+
95
+ __all__ = [
96
+ "FLOAT_PRECISION",
97
+ "NULL_SENTINEL",
98
+ "canonical_row_hash",
99
+ "result_matches",
100
+ "row_set_jaccard",
101
+ ]
models.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Public data models for SQLDrift.
2
+
3
+ Rev 3 design notes enforced here:
4
+
5
+ - Action is a discriminated union over a public `kind: Literal[...]` tag on
6
+ each payload sub-model. Pydantic v2 forbids leading-underscore names as
7
+ discriminator keys (reserved for private attrs), so we keep the tag public.
8
+ - `SqlDriftAction` cross-validates that the envelope-level `tool` matches
9
+ `payload.kind` (prevents inconsistent envelopes from being constructed).
10
+ - `SqlDriftObservation.tool_result` is itself a discriminated union over the
11
+ eight concrete result types plus `ToolError` (for in-env semantic failures;
12
+ envelope-level `ValidationError` is a transport-layer concern, not an in-env code).
13
+ - `SqlDriftState` is the public state snapshot shipped over `/state`. It
14
+ never carries ground truth, DB handles, baseline runtime, or seeds;
15
+ `extra="forbid"` guarantees no accidental leak as new fields are added.
16
+ The private `RuntimeEpisodeState` lives in :mod:`engine.runtime`.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ from enum import StrEnum
22
+ from typing import Annotated, Any, Literal
23
+
24
+ from openenv.core.env_server.types import Action, Observation, State
25
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
26
+ from pydantic_core import PydanticCustomError
27
+
28
+ # =============================================================================
29
+ # Enums
30
+ # =============================================================================
31
+
32
+
33
+ class ToolName(StrEnum):
34
+ LIST_TABLES = "list_tables"
35
+ DESCRIBE_TABLE = "describe_table"
36
+ SAMPLE_ROWS = "sample_rows"
37
+ RUN_QUERY = "run_query"
38
+ EXPLAIN_QUERY = "explain_query"
39
+ READ_CHANGELOG = "read_changelog"
40
+ SUBMIT_REWRITE = "submit_rewrite"
41
+ CONSULT_DBA = "consult_dba"
42
+
43
+
44
+ class EpisodePhase(StrEnum):
45
+ DIAGNOSE = "diagnose"
46
+ REWRITE = "rewrite"
47
+ DRIFT_RECOVERY = "drift_recovery"
48
+ FINALIZE = "finalize"
49
+
50
+
51
+ class ToolErrorCode(StrEnum):
52
+ """In-environment semantic failure codes (API contract).
53
+
54
+ Envelope-level `pydantic.ValidationError` is handled by the OpenEnv
55
+ transport layer (HTTP 422 / `/ws` error frame) and never reaches
56
+ `env.step`, so it has no code here.
57
+ """
58
+
59
+ DB_ERROR = "db_error"
60
+ UNKNOWN_TABLE = "unknown_table"
61
+ QUERY_TIMEOUT = "query_timeout"
62
+ RESULT_TOO_LARGE = "result_too_large"
63
+ SUBMIT_BEFORE_DIAGNOSE = "submit_before_diagnose"
64
+ INVALID_TOOL_ARGUMENT = "invalid_tool_argument"
65
+
66
+
67
+ # =============================================================================
68
+ # Tool payloads (request side of `SqlDriftAction`)
69
+ # =============================================================================
70
+
71
+
72
+ class _BasePayload(BaseModel):
73
+ """Shared config for every tool-call payload."""
74
+
75
+ model_config = ConfigDict(extra="forbid", validate_assignment=True)
76
+
77
+
78
+ class ListTablesPayload(_BasePayload):
79
+ kind: Literal["list_tables"] = "list_tables"
80
+
81
+
82
+ class DescribeTablePayload(_BasePayload):
83
+ kind: Literal["describe_table"] = "describe_table"
84
+ table: str = Field(min_length=1, max_length=63)
85
+
86
+
87
+ class SampleRowsPayload(_BasePayload):
88
+ kind: Literal["sample_rows"] = "sample_rows"
89
+ table: str = Field(min_length=1, max_length=63)
90
+ limit: int = Field(default=5, ge=1, le=5)
91
+
92
+
93
+ class RunQueryPayload(_BasePayload):
94
+ kind: Literal["run_query"] = "run_query"
95
+ sql: str = Field(min_length=1, max_length=10_000)
96
+
97
+
98
+ class ExplainQueryPayload(_BasePayload):
99
+ kind: Literal["explain_query"] = "explain_query"
100
+ sql: str = Field(min_length=1, max_length=10_000)
101
+
102
+
103
+ class ReadChangelogPayload(_BasePayload):
104
+ kind: Literal["read_changelog"] = "read_changelog"
105
+
106
+
107
+ class SubmitRewritePayload(_BasePayload):
108
+ kind: Literal["submit_rewrite"] = "submit_rewrite"
109
+ sql: str = Field(min_length=1, max_length=10_000)
110
+
111
+
112
+ class ConsultDBAPayload(_BasePayload):
113
+ kind: Literal["consult_dba"] = "consult_dba"
114
+ question: str = Field(min_length=1, max_length=400)
115
+
116
+
117
+ ToolPayload = Annotated[
118
+ ListTablesPayload
119
+ | DescribeTablePayload
120
+ | SampleRowsPayload
121
+ | RunQueryPayload
122
+ | ExplainQueryPayload
123
+ | ReadChangelogPayload
124
+ | SubmitRewritePayload
125
+ | ConsultDBAPayload,
126
+ Field(discriminator="kind"),
127
+ ]
128
+
129
+
130
+ # Tool -> payload-kind mapping; single source of truth for cross-validation
131
+ # and for the server-side dispatcher in P7.
132
+ TOOL_TO_PAYLOAD_KIND: dict[ToolName, str] = {
133
+ ToolName.LIST_TABLES: "list_tables",
134
+ ToolName.DESCRIBE_TABLE: "describe_table",
135
+ ToolName.SAMPLE_ROWS: "sample_rows",
136
+ ToolName.RUN_QUERY: "run_query",
137
+ ToolName.EXPLAIN_QUERY: "explain_query",
138
+ ToolName.READ_CHANGELOG: "read_changelog",
139
+ ToolName.SUBMIT_REWRITE: "submit_rewrite",
140
+ ToolName.CONSULT_DBA: "consult_dba",
141
+ }
142
+
143
+
144
+ # =============================================================================
145
+ # SqlDriftAction envelope
146
+ # =============================================================================
147
+
148
+
149
+ class SqlDriftAction(Action):
150
+ """Tool-call envelope.
151
+
152
+ JSON wire format::
153
+
154
+ {"tool": "run_query", "payload": {"kind": "run_query", "sql": "..."}}
155
+
156
+ The `tool` field and `payload.kind` must agree; mismatch raises at
157
+ validation time.
158
+ """
159
+
160
+ tool: ToolName
161
+ payload: ToolPayload
162
+
163
+ @model_validator(mode="after")
164
+ def _tool_matches_payload(self) -> SqlDriftAction:
165
+ expected = TOOL_TO_PAYLOAD_KIND[self.tool]
166
+ if self.payload.kind != expected:
167
+ # PydanticCustomError keeps ``ctx`` JSON-serializable (plain
168
+ # strings only), unlike a bare ``ValueError`` which Pydantic
169
+ # wraps with ``ctx={"error": ValueError(...)}`` and breaks
170
+ # FastAPI HTTPException JSON encoder (422 responses).
171
+ raise PydanticCustomError(
172
+ "tool_payload_mismatch",
173
+ "tool/payload mismatch: tool={tool} expects payload.kind={expected}, got {got}",
174
+ {
175
+ "tool": self.tool.value,
176
+ "expected": expected,
177
+ "got": self.payload.kind,
178
+ },
179
+ )
180
+ return self
181
+
182
+
183
+ # =============================================================================
184
+ # Tool results (response side of `SqlDriftObservation.tool_result`)
185
+ # =============================================================================
186
+
187
+
188
+ class _BaseResult(BaseModel):
189
+ model_config = ConfigDict(extra="forbid", validate_assignment=True)
190
+
191
+
192
+ class ListTablesResult(_BaseResult):
193
+ kind: Literal["list_tables_result"] = "list_tables_result"
194
+ tables: list[str]
195
+
196
+
197
+ class DescribeTableResult(_BaseResult):
198
+ kind: Literal["describe_table_result"] = "describe_table_result"
199
+ table: str
200
+ columns: list[dict[str, str]] # [{"name": "...", "type": "..."}]
201
+
202
+
203
+ class SampleRowsResult(_BaseResult):
204
+ kind: Literal["sample_rows_result"] = "sample_rows_result"
205
+ table: str
206
+ columns: list[str]
207
+ rows: list[list[Any]]
208
+
209
+
210
+ class RunQueryResult(_BaseResult):
211
+ kind: Literal["run_query_result"] = "run_query_result"
212
+ columns: list[str]
213
+ rows: list[list[Any]]
214
+ runtime_ms: float
215
+ row_count: int
216
+
217
+
218
+ class ExplainQueryResult(_BaseResult):
219
+ kind: Literal["explain_query_result"] = "explain_query_result"
220
+ plan: str
221
+
222
+
223
+ class ReadChangelogResult(_BaseResult):
224
+ kind: Literal["read_changelog_result"] = "read_changelog_result"
225
+ entries: list[str]
226
+
227
+
228
+ class SubmitRewriteResult(_BaseResult):
229
+ kind: Literal["submit_rewrite_result"] = "submit_rewrite_result"
230
+ accepted: bool
231
+ runtime_ms: float
232
+ matches_ground_truth: bool
233
+
234
+
235
+ class ConsultDBAResult(_BaseResult):
236
+ kind: Literal["consult_dba_result"] = "consult_dba_result"
237
+ tier: int = Field(ge=1, le=3)
238
+ hint: str
239
+
240
+
241
+ class ToolError(_BaseResult):
242
+ kind: Literal["tool_error"] = "tool_error"
243
+ code: ToolErrorCode
244
+ message: str = Field(max_length=2_000)
245
+
246
+
247
+ ToolResult = Annotated[
248
+ ListTablesResult
249
+ | DescribeTableResult
250
+ | SampleRowsResult
251
+ | RunQueryResult
252
+ | ExplainQueryResult
253
+ | ReadChangelogResult
254
+ | SubmitRewriteResult
255
+ | ConsultDBAResult
256
+ | ToolError,
257
+ Field(discriminator="kind"),
258
+ ]
259
+
260
+
261
+ # The six reward-component keys match the composed rubric; tests and telemetry
262
+ # rely on this exact schema.
263
+ REWARD_COMPONENT_KEYS: tuple[str, ...] = (
264
+ "r_correct",
265
+ "r_drift",
266
+ "r_speedup",
267
+ "r_step_tax",
268
+ "r_gatekeepers",
269
+ "r_consult_dba",
270
+ )
271
+
272
+
273
+ # =============================================================================
274
+ # SqlDriftObservation
275
+ # =============================================================================
276
+
277
+
278
+ def _zero_reward_components() -> dict[str, float]:
279
+ """Six-key reward envelope initialised to zero.
280
+
281
+ Every observation, including the reset observation, carries the full
282
+ six-key schema so telemetry and tests can index it unconditionally.
283
+ """
284
+ return {key: 0.0 for key in REWARD_COMPONENT_KEYS}
285
+
286
+
287
+ class SqlDriftObservation(Observation):
288
+ """Observation returned by :meth:`SqlDriftEnvironment.step`.
289
+
290
+ Inherits `done: bool` and `reward: float | None` from base Observation.
291
+
292
+ The task payload (`baseline_sql`, `schema_synopsis`) is delivered on
293
+ the reset observation and kept empty on subsequent steps: the agent
294
+ is expected to capture it once and hold it in its own context.
295
+ """
296
+
297
+ step: int = Field(ge=0)
298
+ phase: EpisodePhase
299
+ last_tool: ToolName | None = None
300
+ tool_result: ToolResult | None = None
301
+ drift_fired: bool = False
302
+ drift_acknowledged: bool = False
303
+ learned_hints: str = Field(default="", max_length=800)
304
+ baseline_sql: str = Field(default="", max_length=10_000)
305
+ schema_synopsis: str = Field(default="", max_length=2_000)
306
+ budget_steps_remaining: int = Field(ge=0)
307
+ reward_components: dict[str, float] = Field(default_factory=_zero_reward_components)
308
+
309
+
310
+ # =============================================================================
311
+ # SqlDriftState — PUBLIC state (sanitized)
312
+ # =============================================================================
313
+
314
+
315
+ class SqlDriftState(State):
316
+ """Public state snapshot — serialized over `/state`.
317
+
318
+ Ground truth, DB handles, seeds, and baseline SQL live in
319
+ :class:`engine.runtime.RuntimeEpisodeState` and are never exposed here.
320
+ `extra="forbid"` guarantees no accidental leak via future field additions.
321
+ """
322
+
323
+ model_config = ConfigDict(
324
+ extra="forbid",
325
+ validate_assignment=True,
326
+ )
327
+
328
+ scenario_id: str
329
+ phase: EpisodePhase
330
+ budget_steps_remaining: int = Field(ge=0)
331
+ drift_fired: bool = False
332
+ consultations_used: int = Field(default=0, ge=0)
333
+ submitted: bool = False
334
+
335
+
336
+ __all__ = [
337
+ "ConsultDBAPayload",
338
+ "ConsultDBAResult",
339
+ "DescribeTablePayload",
340
+ "DescribeTableResult",
341
+ "EpisodePhase",
342
+ "ExplainQueryPayload",
343
+ "ExplainQueryResult",
344
+ "ListTablesPayload",
345
+ "ListTablesResult",
346
+ "REWARD_COMPONENT_KEYS",
347
+ "ReadChangelogPayload",
348
+ "ReadChangelogResult",
349
+ "RunQueryPayload",
350
+ "RunQueryResult",
351
+ "SampleRowsPayload",
352
+ "SampleRowsResult",
353
+ "SqlDriftAction",
354
+ "SqlDriftObservation",
355
+ "SqlDriftState",
356
+ "SubmitRewritePayload",
357
+ "SubmitRewriteResult",
358
+ "TOOL_TO_PAYLOAD_KIND",
359
+ "ToolError",
360
+ "ToolErrorCode",
361
+ "ToolName",
362
+ "ToolPayload",
363
+ "ToolResult",
364
+ ]
openenv.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: sql_drift_env
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8000
7
+
pyproject.toml ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ [build-system]
8
+ requires = ["setuptools>=45", "wheel"]
9
+ build-backend = "setuptools.build_meta"
10
+
11
+ [project]
12
+ name = "sql_drift_env"
13
+ version = "0.1.0"
14
+ description = "SQLDrift: OpenEnv gym for repairing and optimizing SQL under live schema/business-rule drift"
15
+ requires-python = ">=3.12,<3.14"
16
+ dependencies = [
17
+ "duckdb>=1.5.2,<2.0",
18
+ "huggingface-hub>=0.20,<1.0",
19
+ "openenv-core[core]>=0.2.2,<0.4",
20
+ "sqlglot>=30.6.0,<40.0",
21
+ "pydantic>=2.8.0,<3.0",
22
+ "python-dotenv>=1.2.2,<2.0",
23
+ "openai>=2.32.0,<3.0",
24
+ ]
25
+
26
+ [project.optional-dependencies]
27
+ evidence = [
28
+ "matplotlib>=3.8.0,<4.0",
29
+ "pandas>=2.0.0,<3.0",
30
+ ]
31
+ dev = [
32
+ "mypy>=1.20.1",
33
+ "pytest>=9.0.3",
34
+ "pytest-asyncio>=1.3.0",
35
+ "pytest-cov>=7.0.0",
36
+ "ruff>=0.15.11",
37
+ "httpx>=0.28.0",
38
+ ]
39
+ train = [
40
+ "trl>=0.25.0,<1.0",
41
+ "datasets>=2.20.0,<4.0",
42
+ "transformers>=4.46.0,<5.0",
43
+ "accelerate>=1.0.0,<2.0",
44
+ "peft>=0.13,<1.0",
45
+ "bitsandbytes>=0.43,<1.0",
46
+ "unsloth>=2024.9,<2027.0",
47
+ ]
48
+
49
+ [project.scripts]
50
+ # Flat package layout; see _cli.py. Runtime path: site-packages/sql_drift_env/
51
+ server = "sql_drift_env._cli:main"
52
+
53
+ [tool.setuptools]
54
+ include-package-data = true
55
+ packages = [
56
+ "sql_drift_env",
57
+ "sql_drift_env.scenarios",
58
+ "sql_drift_env.engine",
59
+ "sql_drift_env.skill_library",
60
+ "sql_drift_env.actors",
61
+ "sql_drift_env.training",
62
+ "sql_drift_env.server",
63
+ "sql_drift_env.utilities",
64
+ ]
65
+
66
+ [tool.setuptools.package-dir]
67
+ "sql_drift_env" = "."
68
+ "sql_drift_env.scenarios" = "scenarios"
69
+ "sql_drift_env.engine" = "engine"
70
+ "sql_drift_env.skill_library" = "skill_library"
71
+ "sql_drift_env.actors" = "actors"
72
+ "sql_drift_env.training" = "training"
73
+ "sql_drift_env.server" = "server"
74
+ "sql_drift_env.utilities" = "utilities"
75
+
76
+ [tool.pytest.ini_options]
77
+ testpaths = ["tests"]
78
+ asyncio_mode = "auto"
79
+ markers = [
80
+ "slow: marks tests as slow (excluded by default; run via `pytest -m slow`)",
81
+ ]
82
+ # Default excludes slow tests so CI stays on a fast path; run `pytest -m slow` for full suite.
83
+ addopts = "-ra --strict-markers -m 'not slow'"
84
+
85
+ [tool.ruff]
86
+ line-length = 100
87
+ target-version = "py312"
88
+ extend-exclude = [".venv", "design", "*.egg-info", "*.ipynb"]
89
+
90
+ [tool.ruff.lint]
91
+ select = ["E", "F", "I", "B", "UP", "SIM"]
92
+ ignore = ["E501"]
93
+
94
+ # Prefer fixing structure over scattered inline suppressions. See each rule in Ruff docs.
95
+ [tool.ruff.lint.per-file-ignores]
96
+ "_cli.py" = ["PLC0415"]
97
+ "training/grpo_train.py" = ["PLC0415"]
98
+ "engine/profiler.py" = ["BLE001"]
99
+ "scenarios/__init__.py" = ["F401"]
100
+ "tests/unit/test_p0_smoke.py" = ["F401"]
101
+ "utilities/verbose_api_rollout.py" = ["E402"]
102
+ "utilities/demo_rollout.py" = ["E402"]
103
+ "tests/unit/test_drift_scenarios.py" = ["E402"]
104
+ "tests/unit/test_profiler.py" = ["E402"]
105
+ "tests/unit/test_drift.py" = ["E402"]
106
+ "tests/unit/test_runtime.py" = ["E402"]
107
+
108
+ [tool.mypy]
109
+ python_version = "3.12"
110
+ strict = true
111
+ ignore_missing_imports = true
112
+ explicit_package_bases = true
113
+ mypy_path = "."
114
+ files = [
115
+ "_cli.py",
116
+ "actors",
117
+ "client.py",
118
+ "engine",
119
+ "models.py",
120
+ "scenarios",
121
+ "utilities",
122
+ "server",
123
+ "skill_library",
124
+ "training",
125
+ ]
126
+ exclude = [
127
+ "\\.venv",
128
+ "design",
129
+ ".*\\.egg-info",
130
+ "tests",
131
+ "build",
132
+ "dist",
133
+ ]
134
+
135
+ # Subclasses in these modules call into openenv-core (no stubs). Relax only here.
136
+ [[tool.mypy.overrides]]
137
+ module = [
138
+ "client",
139
+ "models",
140
+ "engine.reward",
141
+ "server.app",
142
+ "server.sql_drift_env_environment",
143
+ "training.grpo_train",
144
+ ]
145
+ disallow_subclassing_any = false
146
+ warn_return_any = false
147
+
148
+ [tool.coverage.run]
149
+ # pytest-cov can hit upstream issues with beartype when tracing; run plain pytest for CI.
150
+ source = ["engine", "scenarios", "skill_library", "actors", "models.py"]
151
+ omit = ["tests/*", ".venv/*"]
152
+
153
+ [tool.coverage.report]
154
+ skip_empty = true
155
+ exclude_lines = [
156
+ "pragma: no cover",
157
+ "if __name__ == \"__main__\":",
158
+ "if TYPE_CHECKING:",
159
+ "raise NotImplementedError",
160
+ "if 0:",
161
+ "if False:",
162
+ "\\.\\.\\.",
163
+ ]
scenarios/__init__.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scenario registry + auto-discovery.
2
+
3
+ Every sibling module whose name starts with ``scenario_`` is expected to
4
+ export a module-level ``SPEC: ScenarioSpec``. This module walks the package
5
+ at import time and populates :data:`REGISTRY` so callers (env reset, tests,
6
+ CLI eval) can look scenarios up by id without knowing the file layout.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import importlib
12
+ import pkgutil
13
+
14
+ from .base import (
15
+ DriftConfig,
16
+ DriftKind,
17
+ Family,
18
+ ScenarioInstance,
19
+ ScenarioSpec,
20
+ )
21
+
22
+ REGISTRY: dict[str, ScenarioSpec] = {}
23
+
24
+
25
+ def _discover_specs() -> None:
26
+ """Import every `scenario_*` sibling module and harvest their SPEC."""
27
+ package = __name__ # "scenarios"
28
+ package_path = __path__ # provided by Python's package machinery
29
+
30
+ for info in pkgutil.iter_modules(package_path):
31
+ if not info.name.startswith("scenario_"):
32
+ continue
33
+ module = importlib.import_module(f"{package}.{info.name}")
34
+ spec = getattr(module, "SPEC", None)
35
+ if spec is None:
36
+ raise RuntimeError(f"{package}.{info.name} is missing a module-level `SPEC` export")
37
+ if not isinstance(spec, ScenarioSpec):
38
+ raise TypeError(
39
+ f"{package}.{info.name}.SPEC is {type(spec).__name__}; expected ScenarioSpec"
40
+ )
41
+ if spec.scenario_id in REGISTRY:
42
+ raise RuntimeError(
43
+ f"duplicate scenario_id {spec.scenario_id!r} — "
44
+ f"already registered from {REGISTRY[spec.scenario_id]!r}"
45
+ )
46
+ REGISTRY[spec.scenario_id] = spec
47
+
48
+
49
+ def iter_specs() -> list[ScenarioSpec]:
50
+ return sorted(REGISTRY.values(), key=lambda s: s.scenario_id)
51
+
52
+
53
+ def get_spec(scenario_id: str) -> ScenarioSpec:
54
+ try:
55
+ return REGISTRY[scenario_id]
56
+ except KeyError as e:
57
+ raise KeyError(f"unknown scenario_id={scenario_id!r}; known: {sorted(REGISTRY)}") from e
58
+
59
+
60
+ _discover_specs()
61
+
62
+
63
+ __all__ = [
64
+ "DriftConfig",
65
+ "DriftKind",
66
+ "Family",
67
+ "REGISTRY",
68
+ "ScenarioInstance",
69
+ "ScenarioSpec",
70
+ "get_spec",
71
+ "iter_specs",
72
+ ]
scenarios/_fixtures.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Deterministic fixture generators (stdlib only).
2
+
3
+ All distributions are backed by :class:`random.Random(seed)` so a given
4
+ ``(scenario_id, seed, scale)`` tuple always yields the same table contents.
5
+
6
+ Exposes a single :func:`seeded_rng` factory plus a handful of domain-specific
7
+ generators used by the concrete scenarios. No numpy/pandas runtime deps.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import math
13
+ import random
14
+ from collections.abc import Sequence
15
+ from datetime import UTC
16
+
17
+
18
+ def seeded_rng(*parts: int | str) -> random.Random:
19
+ """Derive a deterministic `random.Random` from mixed scalar parts.
20
+
21
+ Uses a stable 64-bit SplitMix-style hash over ``repr(parts)`` — avoids
22
+ Python's per-interpreter salted ``hash()`` for str.
23
+ """
24
+ h = 1469598103934665603 # FNV-1a 64-bit offset basis
25
+ for p in parts:
26
+ for byte in repr(p).encode():
27
+ h ^= byte
28
+ h = (h * 1099511628211) & 0xFFFFFFFFFFFFFFFF
29
+ return random.Random(h)
30
+
31
+
32
+ def lognormal_amounts(rng: random.Random, n: int, mu: float, sigma: float) -> list[float]:
33
+ """n log-normally distributed positive amounts, rounded to 2dp."""
34
+ return [round(rng.lognormvariate(mu, sigma), 2) for _ in range(n)]
35
+
36
+
37
+ def zipfian_choices(
38
+ rng: random.Random, population: Sequence[int], n: int, *, alpha: float = 1.07
39
+ ) -> list[int]:
40
+ """n draws from `population` with zipf(alpha) weights.
41
+
42
+ `alpha` controls skew: 1.07 is the project default. Uses
43
+ `random.choices` with explicit weights, so the draws are stable across
44
+ Python versions for a given `rng` state.
45
+ """
46
+ weights = [1.0 / ((i + 1) ** alpha) for i in range(len(population))]
47
+ return rng.choices(list(population), weights=weights, k=n)
48
+
49
+
50
+ def date_range_epoch_ms(
51
+ rng: random.Random,
52
+ n: int,
53
+ *,
54
+ start_epoch_ms: int,
55
+ window_days: int,
56
+ ) -> list[int]:
57
+ """n random timestamps (ms) within `[start, start + window_days)`."""
58
+ span_ms = window_days * 86_400_000
59
+ return [start_epoch_ms + rng.randrange(span_ms) for _ in range(n)]
60
+
61
+
62
+ def iso_strings_from_epoch_ms(epoch_ms: list[int]) -> list[str]:
63
+ """Convert epoch ms to ISO-8601 UTC strings (matching DuckDB's native coerce)."""
64
+ from datetime import datetime
65
+
66
+ return [
67
+ datetime.fromtimestamp(t / 1000, tz=UTC).isoformat().replace("+00:00", "Z")
68
+ for t in epoch_ms
69
+ ]
70
+
71
+
72
+ def categorical_choices(
73
+ rng: random.Random,
74
+ categories: Sequence[str],
75
+ n: int,
76
+ *,
77
+ weights: Sequence[float] | None = None,
78
+ ) -> list[str]:
79
+ return rng.choices(list(categories), weights=list(weights) if weights else None, k=n)
80
+
81
+
82
+ def unique_names(rng: random.Random, n: int, *, prefix: str = "name") -> list[str]:
83
+ """Stable pseudo-unique string IDs of the form `<prefix>_<64-bit-hex>`."""
84
+ return [f"{prefix}_{rng.getrandbits(64):016x}" for _ in range(n)]
85
+
86
+
87
+ def approx_normal(
88
+ rng: random.Random, n: int, *, mu: float, sigma: float, clip_lo: float | None = None
89
+ ) -> list[float]:
90
+ """n normal draws, optionally clipped below."""
91
+ out: list[float] = []
92
+ for _ in range(n):
93
+ x = rng.gauss(mu, sigma)
94
+ if clip_lo is not None and x < clip_lo:
95
+ x = clip_lo
96
+ out.append(round(x, 4))
97
+ return out
98
+
99
+
100
+ def sanity_nonzero_variance(xs: Sequence[float]) -> bool:
101
+ """Guard: reject obviously degenerate distributions (used in smoke tests)."""
102
+ if not xs:
103
+ return False
104
+ mean = sum(xs) / len(xs)
105
+ var = sum((x - mean) ** 2 for x in xs) / len(xs)
106
+ return var > 1e-12 and not math.isnan(var)
107
+
108
+
109
+ __all__ = [
110
+ "approx_normal",
111
+ "categorical_choices",
112
+ "date_range_epoch_ms",
113
+ "iso_strings_from_epoch_ms",
114
+ "lognormal_amounts",
115
+ "sanity_nonzero_variance",
116
+ "seeded_rng",
117
+ "unique_names",
118
+ "zipfian_choices",
119
+ ]
scenarios/base.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scenario spec + instance primitives.
2
+
3
+ Each concrete scenario file in :mod:`scenarios` exports:
4
+
5
+ - ``SPEC: ScenarioSpec`` — the immutable metadata (id, family, tags,
6
+ optional drift config) plus a bound ``builder`` callable.
7
+
8
+ The builder takes ``(spec, seed, scale)`` and returns a ready-to-attach
9
+ :class:`ScenarioInstance` whose DuckDB connection has been loaded with
10
+ deterministic fixtures, ground-truth hashes pre-computed, and baseline
11
+ runtime measured. ``base_scale`` is author-tuned per scenario so the
12
+ measured baseline clears :data:`BASELINE_MIN_MS` on a single build —
13
+ the old timing-driven reroll loop was removed because it coupled the
14
+ fixture RNG seed to the retry count, which destroyed determinism
15
+ whenever CI hit a jitter-induced retry.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ from collections.abc import Callable
21
+ from dataclasses import dataclass, field
22
+ from typing import TYPE_CHECKING, Any, Literal
23
+
24
+ if TYPE_CHECKING:
25
+ import duckdb
26
+
27
+ from utilities.logger import get_module_logger
28
+
29
+ _LOG = get_module_logger(__name__)
30
+
31
+ Family = Literal["ecommerce", "events", "cms", "saas_logs", "multitenant"]
32
+ DriftKind = Literal["column_rename", "date_format", "enum_rule", "field_deprecation"]
33
+ DifficultyLevel = Literal["easy", "normal", "hard"]
34
+
35
+ # Baseline floor — empirically calibrated, not a fixed marketing target
36
+ # 50 ms. DuckDB on in-memory fixtures at CI-reasonable scales (a few
37
+ # tens of thousands of rows) measures warm baselines of 0.4–2 ms on
38
+ # the anti-pattern queries shipped here; reaching 50 ms would require
39
+ # multi-minute fixture builds per scenario, which is untenable for
40
+ # both CI and RL rollouts (every reset rebuilds).
41
+ #
42
+ # 0.3 ms is ~3–5× the median-of-3 warm jitter floor on a quiet CPU
43
+ # (observed jitter ~60–100 µs). This SNR is tight but workable because
44
+ # the rubric gates the speedup reward at 1.2× before any
45
+ # credit is issued, so jitter-induced near-1× "speedups" score zero.
46
+ # The cap at 64× bounds upside. A 2× rewrite against a 0.3 ms baseline
47
+ # lands at 0.15 ms — still distinguishable from jitter under
48
+ # median-of-3 smoothing.
49
+ #
50
+ # The same floor applies in production and CI — no env-var escape
51
+ # hatch — so tests exercise the real reward distribution. Per-scenario
52
+ # overrides may raise *or* lower this floor when a scenario's query
53
+ # shape has a different natural baseline (see the field docstring on
54
+ # :class:`ScenarioSpec.baseline_min_ms`).
55
+ BASELINE_MIN_MS = 0.3
56
+
57
+
58
+ @dataclass(frozen=True)
59
+ class DriftConfig:
60
+ kind: DriftKind
61
+ payload: dict[str, Any]
62
+ min_step: int = 6
63
+ max_step: int = 12
64
+ cooldown_steps: int = 2
65
+
66
+ def __post_init__(self) -> None:
67
+ if self.min_step < 1:
68
+ raise ValueError("min_step must be >= 1")
69
+ if self.max_step < self.min_step:
70
+ raise ValueError("max_step must be >= min_step")
71
+ if self.cooldown_steps < 0:
72
+ raise ValueError("cooldown_steps must be >= 0")
73
+
74
+
75
+ @dataclass
76
+ class ScenarioInstance:
77
+ """Concretized scenario — ready-to-attach DuckDB fixture + ground truths."""
78
+
79
+ conn: duckdb.DuckDBPyConnection
80
+ baseline_sql: str
81
+ gt_sql_predrift: str
82
+ gt_sql_postdrift: str | None
83
+ baseline_runtime_ms: float
84
+ baseline_tokens: int
85
+ gt_result_hash_predrift: str
86
+ gt_result_hash_postdrift: str | None
87
+ drift_config: DriftConfig | None
88
+ schema_synopsis: str
89
+ # Drift-distinctive identifier sets consumed by the drift-adapt
90
+ # rubric. ``postdrift_identifiers`` marks identifiers/literals
91
+ # the correct post-drift rewrite MUST introduce; ``predrift_identifiers``
92
+ # marks identifiers/literals a submission that ignored the drift
93
+ # WOULD retain. Together they let the rubric distinguish "adapted"
94
+ # from "did not adapt" for drift kinds where a single identifier
95
+ # (e.g. ``ts`` under date-format drift) is shared by both sides.
96
+ postdrift_identifiers: frozenset[str] = field(default_factory=frozenset)
97
+ predrift_identifiers: frozenset[str] = field(default_factory=frozenset)
98
+
99
+
100
+ # Builder signature: (spec, seed, scale) -> (conn, baseline_sql,
101
+ # gt_sql_predrift, gt_sql_postdrift, schema_synopsis,
102
+ # postdrift_identifiers, predrift_identifiers).
103
+ BuilderResult = tuple[
104
+ "duckdb.DuckDBPyConnection",
105
+ str, # baseline_sql
106
+ str, # gt_sql_predrift
107
+ str | None, # gt_sql_postdrift
108
+ str, # schema_synopsis
109
+ frozenset[str], # postdrift_identifiers
110
+ frozenset[str], # predrift_identifiers
111
+ ]
112
+ BuilderFn = Callable[["ScenarioSpec", int, int], BuilderResult]
113
+
114
+
115
+ @dataclass(frozen=True)
116
+ class ScenarioSpec:
117
+ """Immutable scenario metadata + bound builder."""
118
+
119
+ scenario_id: str
120
+ family: Family
121
+ tags: frozenset[str]
122
+ drift_config: DriftConfig | None
123
+ builder: BuilderFn
124
+ # Row-count scale passed to the builder. Author-tuned so the
125
+ # measured baseline clears ``baseline_min_ms`` on a single build;
126
+ # materialize() emits a warning (but does not retry) if the floor
127
+ # is not met, signalling the author to bump this value.
128
+ base_scale: int = 1_000
129
+ # Per-scenario baseline floor override. Most scenarios inherit the
130
+ # module default. Scenarios whose query shape naturally lands at a
131
+ # very different baseline (e.g. a trivial single-table GROUP BY
132
+ # that can't be meaningfully sped up, or a large join whose raw
133
+ # shape is already expensive) can pin a different floor with a
134
+ # documented rationale at the SPEC site.
135
+ baseline_min_ms: float = BASELINE_MIN_MS
136
+
137
+ def materialize(self, seed: int, *, difficulty: DifficultyLevel = "normal") -> ScenarioInstance:
138
+ return materialize(self, seed, difficulty=difficulty)
139
+
140
+
141
+ def count_tokens(sql: str) -> int:
142
+ """Rough whitespace/punctuation token count — good enough for baseline."""
143
+ import re
144
+
145
+ return len(re.findall(r"[\w]+|[^\s\w]", sql))
146
+
147
+
148
+ def _scale_for_difficulty(base_scale: int, difficulty: DifficultyLevel) -> int:
149
+ """Map a coarse difficulty level onto the scenario builder's row-count scale."""
150
+ if difficulty == "easy":
151
+ return max(1, base_scale // 2)
152
+ if difficulty == "hard":
153
+ return base_scale * 2
154
+ return base_scale
155
+
156
+
157
+ def materialize(
158
+ spec: ScenarioSpec, seed: int, *, difficulty: DifficultyLevel = "normal"
159
+ ) -> ScenarioInstance:
160
+ """Build a ScenarioInstance once, measure baseline, and return.
161
+
162
+ Single build — deterministic, no retry. If the measured baseline is
163
+ below ``spec.baseline_min_ms`` a warning is logged so scenario
164
+ authors can bump ``base_scale``; the instance is still returned so
165
+ episodes can proceed (the rubric gracefully handles small
166
+ baselines via the 1.2× speedup gate and infinite-speedup cap).
167
+ """
168
+ from engine.profiler import median_of_3_warm_ms
169
+ from engine.verifier import canonical_row_hash
170
+
171
+ scale = _scale_for_difficulty(spec.base_scale, difficulty)
172
+
173
+ (
174
+ conn,
175
+ baseline_sql,
176
+ gt_pre,
177
+ gt_post,
178
+ synopsis,
179
+ postdrift_ids,
180
+ predrift_ids,
181
+ ) = spec.builder(spec, seed, scale)
182
+ try:
183
+ baseline_ms = median_of_3_warm_ms(conn, baseline_sql)
184
+ except Exception:
185
+ conn.close()
186
+ raise
187
+ if baseline_ms < spec.baseline_min_ms:
188
+ _LOG.warning(
189
+ "%s: baseline %.2fms < %.2fms floor at difficulty=%s scale=%d — bump base_scale",
190
+ spec.scenario_id,
191
+ baseline_ms,
192
+ spec.baseline_min_ms,
193
+ difficulty,
194
+ scale,
195
+ )
196
+
197
+ pre_rows = conn.execute(gt_pre).fetchall()
198
+ gt_hash_pre = canonical_row_hash(pre_rows)
199
+ # Post-drift ground-truth hashes are computed AFTER drift is applied
200
+ # at runtime — not here. The env backfills them from gt_post once
201
+ # drift fires.
202
+ return ScenarioInstance(
203
+ conn=conn,
204
+ baseline_sql=baseline_sql,
205
+ gt_sql_predrift=gt_pre,
206
+ gt_sql_postdrift=gt_post,
207
+ baseline_runtime_ms=baseline_ms,
208
+ baseline_tokens=count_tokens(baseline_sql),
209
+ gt_result_hash_predrift=gt_hash_pre,
210
+ gt_result_hash_postdrift=None,
211
+ drift_config=spec.drift_config,
212
+ schema_synopsis=synopsis,
213
+ postdrift_identifiers=postdrift_ids,
214
+ predrift_identifiers=predrift_ids,
215
+ )
216
+
217
+
218
+ __all__ = [
219
+ "BASELINE_MIN_MS",
220
+ "BuilderFn",
221
+ "BuilderResult",
222
+ "DifficultyLevel",
223
+ "DriftConfig",
224
+ "DriftKind",
225
+ "Family",
226
+ "ScenarioInstance",
227
+ "ScenarioSpec",
228
+ "count_tokens",
229
+ "materialize",
230
+ ]
scenarios/scenario_01_correlated_subquery.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scenario 01 — correlated subquery anti-pattern.
2
+
3
+ Baseline computes a per-user order-count via a correlated subquery in the
4
+ projection list. The rewrite replaces it with a LEFT JOIN on a grouped
5
+ aggregate so the query runs once instead of once per outer row.
6
+
7
+ Schema: users(id, signup_month, tier), orders(id, user_id, amount, status).
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import duckdb
13
+
14
+ from ._fixtures import (
15
+ categorical_choices,
16
+ lognormal_amounts,
17
+ seeded_rng,
18
+ zipfian_choices,
19
+ )
20
+ from .base import BuilderResult, ScenarioSpec
21
+
22
+
23
+ def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
24
+ rng = seeded_rng(spec.scenario_id, seed, scale)
25
+ n_users = scale
26
+ n_orders = scale * 6
27
+
28
+ user_ids = list(range(1, n_users + 1))
29
+ tiers = categorical_choices(
30
+ rng, ["free", "pro", "business"], n_users, weights=[0.7, 0.25, 0.05]
31
+ )
32
+ signup_months = rng.choices(list(range(1, 13)), k=n_users)
33
+
34
+ order_user_ids = zipfian_choices(rng, user_ids, n_orders)
35
+ amounts = lognormal_amounts(rng, n_orders, mu=3.2, sigma=0.9)
36
+ statuses = categorical_choices(
37
+ rng,
38
+ ["placed", "fulfilled", "refunded", "cancelled"],
39
+ n_orders,
40
+ weights=[0.6, 0.3, 0.05, 0.05],
41
+ )
42
+
43
+ conn = duckdb.connect(":memory:")
44
+ conn.execute("CREATE TABLE users( id BIGINT PRIMARY KEY, signup_month INTEGER, tier VARCHAR);")
45
+ conn.execute(
46
+ "CREATE TABLE orders("
47
+ " id BIGINT PRIMARY KEY, user_id BIGINT, amount DOUBLE, status VARCHAR);"
48
+ )
49
+ conn.executemany(
50
+ "INSERT INTO users VALUES (?, ?, ?)",
51
+ list(zip(user_ids, signup_months, tiers, strict=False)),
52
+ )
53
+ conn.executemany(
54
+ "INSERT INTO orders VALUES (?, ?, ?, ?)",
55
+ [
56
+ (oid, uid, amt, st)
57
+ for oid, (uid, amt, st) in enumerate(
58
+ zip(order_user_ids, amounts, statuses, strict=False), start=1
59
+ )
60
+ ],
61
+ )
62
+
63
+ baseline_sql = (
64
+ "SELECT u.id, u.tier, "
65
+ "(SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id AND o.status = 'fulfilled') "
66
+ " AS fulfilled_orders "
67
+ "FROM users u "
68
+ "ORDER BY u.id"
69
+ )
70
+ gt_sql_predrift = (
71
+ "SELECT u.id, u.tier, COALESCE(c.n, 0) AS fulfilled_orders "
72
+ "FROM users u LEFT JOIN ("
73
+ " SELECT user_id, COUNT(*) AS n FROM orders "
74
+ " WHERE status = 'fulfilled' GROUP BY user_id"
75
+ ") c ON c.user_id = u.id "
76
+ "ORDER BY u.id"
77
+ )
78
+
79
+ synopsis = (
80
+ "users(id PK, signup_month, tier); orders(id PK, user_id→users.id, amount, status). "
81
+ "Baseline scans orders once per user via a correlated subquery."
82
+ )
83
+ return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset()
84
+
85
+
86
+ SPEC = ScenarioSpec(
87
+ scenario_id="01_correlated_subquery",
88
+ family="ecommerce",
89
+ tags=frozenset({"correlated_subquery", "projection_subquery", "ecommerce"}),
90
+ drift_config=None,
91
+ builder=_build,
92
+ base_scale=800,
93
+ )
scenarios/scenario_02_select_star_join.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scenario 02 — SELECT * over a join anti-pattern.
2
+
3
+ Baseline pulls every column from three joined tables when only two columns
4
+ are needed. The rewrite projects exactly the requested columns.
5
+
6
+ Schema: products(id, sku, name, description, price_cents, …), orders(id, user_id, created_at),
7
+ order_items(order_id, product_id, qty, unit_price_cents).
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import duckdb
13
+
14
+ from ._fixtures import (
15
+ categorical_choices,
16
+ lognormal_amounts,
17
+ seeded_rng,
18
+ unique_names,
19
+ zipfian_choices,
20
+ )
21
+ from .base import BuilderResult, ScenarioSpec
22
+
23
+
24
+ def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
25
+ rng = seeded_rng(spec.scenario_id, seed, scale)
26
+ n_products = max(100, scale // 4)
27
+ n_orders = scale
28
+ n_items = scale * 3
29
+ n_users = max(50, scale // 8)
30
+
31
+ product_ids = list(range(1, n_products + 1))
32
+ skus = unique_names(rng, n_products, prefix="sku")
33
+ names = unique_names(rng, n_products, prefix="p")
34
+ descriptions = [f"Long marketing copy for {n}" * 6 for n in names] # wide col
35
+ prices = [int(x * 100) for x in lognormal_amounts(rng, n_products, mu=3.0, sigma=1.1)]
36
+ categories = categorical_choices(
37
+ rng, ["books", "electronics", "apparel", "grocery", "home"], n_products
38
+ )
39
+
40
+ user_ids = list(range(1, n_users + 1))
41
+ order_user_ids = zipfian_choices(rng, user_ids, n_orders)
42
+ order_created = [1_700_000_000 + rng.randrange(60 * 86_400) for _ in range(n_orders)]
43
+
44
+ item_order_ids = rng.choices(list(range(1, n_orders + 1)), k=n_items)
45
+ item_product_ids = zipfian_choices(rng, product_ids, n_items)
46
+ item_qty = rng.choices([1, 1, 1, 2, 2, 3, 4, 5], k=n_items)
47
+ item_unit_price = [prices[pid - 1] for pid in item_product_ids]
48
+
49
+ conn = duckdb.connect(":memory:")
50
+ conn.execute(
51
+ "CREATE TABLE products("
52
+ " id BIGINT PRIMARY KEY, sku VARCHAR, name VARCHAR, description VARCHAR,"
53
+ " price_cents BIGINT, category VARCHAR);"
54
+ )
55
+ conn.execute(
56
+ "CREATE TABLE orders( id BIGINT PRIMARY KEY, user_id BIGINT, created_at_epoch_s BIGINT);"
57
+ )
58
+ conn.execute(
59
+ "CREATE TABLE order_items("
60
+ " order_id BIGINT, product_id BIGINT, qty INTEGER, unit_price_cents BIGINT);"
61
+ )
62
+ conn.executemany(
63
+ "INSERT INTO products VALUES (?, ?, ?, ?, ?, ?)",
64
+ list(zip(product_ids, skus, names, descriptions, prices, categories, strict=False)),
65
+ )
66
+ conn.executemany(
67
+ "INSERT INTO orders VALUES (?, ?, ?)",
68
+ [
69
+ (oid, uid, ts)
70
+ for oid, (uid, ts) in enumerate(zip(order_user_ids, order_created, strict=False), 1)
71
+ ],
72
+ )
73
+ conn.executemany(
74
+ "INSERT INTO order_items VALUES (?, ?, ?, ?)",
75
+ list(zip(item_order_ids, item_product_ids, item_qty, item_unit_price, strict=False)),
76
+ )
77
+
78
+ # Anti-pattern: SELECT * inside a subquery that wraps the real join, then
79
+ # the outer query projects only a handful of columns. DuckDB materializes
80
+ # every column of the subquery before the projection can prune it.
81
+ baseline_sql = (
82
+ "SELECT t.order_id, t.name, t.qty "
83
+ "FROM ("
84
+ " SELECT * FROM order_items oi "
85
+ " JOIN products p ON p.id = oi.product_id "
86
+ " JOIN orders o ON o.id = oi.order_id "
87
+ " WHERE p.category = 'books' AND oi.qty >= 2"
88
+ ") t "
89
+ "ORDER BY t.order_id, t.name"
90
+ )
91
+ gt_sql_predrift = (
92
+ "SELECT oi.order_id, p.name, oi.qty "
93
+ "FROM order_items oi "
94
+ "JOIN products p ON p.id = oi.product_id "
95
+ "JOIN orders o ON o.id = oi.order_id "
96
+ "WHERE p.category = 'books' AND oi.qty >= 2 "
97
+ "ORDER BY oi.order_id, p.name"
98
+ )
99
+
100
+ synopsis = (
101
+ "products(id PK, sku, name, description, price_cents, category); "
102
+ "orders(id PK, user_id, created_at_epoch_s); "
103
+ "order_items(order_id, product_id, qty, unit_price_cents). "
104
+ "Baseline wraps a three-way join with SELECT * inside a subquery; "
105
+ "only (order_id, product name, qty) are needed downstream."
106
+ )
107
+ return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset()
108
+
109
+
110
+ SPEC = ScenarioSpec(
111
+ scenario_id="02_select_star_join",
112
+ family="ecommerce",
113
+ tags=frozenset({"select_star", "over_projection", "join", "ecommerce"}),
114
+ drift_config=None,
115
+ builder=_build,
116
+ # Three-way join with SELECT * is already well-optimized by DuckDB;
117
+ # a larger base keeps first-try baseline above the 1 ms floor.
118
+ base_scale=1_500,
119
+ )
scenarios/scenario_03_cartesian_join.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scenario 03 — accidental cartesian join.
2
+
3
+ Baseline joins events × tenants without an equijoin clause and relies on a
4
+ post-filter `WHERE` to restore tenant scoping. The optimizer cannot push
5
+ the filter into the join, so it materializes the full |events|·|tenants|
6
+ cross-product first. The rewrite moves the tenant key into an ON clause.
7
+
8
+ Schema: tenants(id, tier), events(id, tenant_id, kind, severity).
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import duckdb
14
+
15
+ from ._fixtures import categorical_choices, seeded_rng, zipfian_choices
16
+ from .base import BuilderResult, ScenarioSpec
17
+
18
+
19
+ def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
20
+ rng = seeded_rng(spec.scenario_id, seed, scale)
21
+ n_tenants = max(20, scale // 40)
22
+ n_events = scale * 4
23
+
24
+ tenant_ids = list(range(1, n_tenants + 1))
25
+ tiers = categorical_choices(
26
+ rng, ["free", "pro", "enterprise"], n_tenants, weights=[0.6, 0.3, 0.1]
27
+ )
28
+ event_tenant_ids = zipfian_choices(rng, tenant_ids, n_events)
29
+ kinds = categorical_choices(
30
+ rng,
31
+ ["login", "logout", "action", "error"],
32
+ n_events,
33
+ weights=[0.35, 0.3, 0.3, 0.05],
34
+ )
35
+ severities = categorical_choices(
36
+ rng,
37
+ ["info", "warn", "error", "critical"],
38
+ n_events,
39
+ weights=[0.7, 0.2, 0.08, 0.02],
40
+ )
41
+
42
+ conn = duckdb.connect(":memory:")
43
+ conn.execute("CREATE TABLE tenants(id BIGINT PRIMARY KEY, tier VARCHAR);")
44
+ conn.execute(
45
+ "CREATE TABLE events("
46
+ " id BIGINT PRIMARY KEY, tenant_id BIGINT, kind VARCHAR, severity VARCHAR);"
47
+ )
48
+ conn.executemany(
49
+ "INSERT INTO tenants VALUES (?, ?)", list(zip(tenant_ids, tiers, strict=False))
50
+ )
51
+ conn.executemany(
52
+ "INSERT INTO events VALUES (?, ?, ?, ?)",
53
+ [
54
+ (i, tid, k, s)
55
+ for i, (tid, k, s) in enumerate(
56
+ zip(event_tenant_ids, kinds, severities, strict=False), start=1
57
+ )
58
+ ],
59
+ )
60
+
61
+ baseline_sql = (
62
+ "SELECT t.tier, COUNT(*) AS n "
63
+ "FROM events e, tenants t "
64
+ "WHERE t.id = e.tenant_id + 0 " # defeat optimizer recognition
65
+ "AND e.severity IN ('error', 'critical') "
66
+ "GROUP BY t.tier ORDER BY t.tier"
67
+ )
68
+ gt_sql_predrift = (
69
+ "SELECT t.tier, COUNT(*) AS n "
70
+ "FROM events e JOIN tenants t ON t.id = e.tenant_id "
71
+ "WHERE e.severity IN ('error', 'critical') "
72
+ "GROUP BY t.tier ORDER BY t.tier"
73
+ )
74
+
75
+ synopsis = (
76
+ "tenants(id PK, tier); events(id PK, tenant_id→tenants.id, kind, severity). "
77
+ "Baseline relies on a WHERE-clause equijoin obscured by arithmetic, "
78
+ "forcing a cartesian materialization."
79
+ )
80
+ return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset()
81
+
82
+
83
+ SPEC = ScenarioSpec(
84
+ scenario_id="03_cartesian_join",
85
+ family="events",
86
+ tags=frozenset({"cartesian", "missing_join_condition", "events"}),
87
+ drift_config=None,
88
+ builder=_build,
89
+ # Cartesian materialization cost grows with |events|·|tenants| so
90
+ # larger base_scale keeps first-try baseline above the 1 ms floor;
91
+ # the reroll loop doubles from here if needed.
92
+ base_scale=1_200,
93
+ )
scenarios/scenario_04_distinct_groupby.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scenario 04 — redundant DISTINCT on top of GROUP BY.
2
+
3
+ Baseline applies DISTINCT after already grouping, forcing a second
4
+ deduplication over a result whose GROUP BY keys are already unique.
5
+ Rewrite drops the DISTINCT entirely.
6
+
7
+ Schema: pageviews(id, session_id, path, ts_epoch_s).
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import duckdb
13
+
14
+ from ._fixtures import categorical_choices, seeded_rng, unique_names, zipfian_choices
15
+ from .base import BuilderResult, ScenarioSpec
16
+
17
+
18
+ def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
19
+ rng = seeded_rng(spec.scenario_id, seed, scale)
20
+ n_sessions = max(200, scale // 4)
21
+ n_views = scale * 8
22
+
23
+ session_ids = unique_names(rng, n_sessions, prefix="sess")
24
+ paths = categorical_choices(
25
+ rng,
26
+ [f"/path/{p}" for p in ["home", "about", "product", "cart", "checkout", "help"]],
27
+ n_views,
28
+ )
29
+ view_sessions = zipfian_choices(rng, list(range(n_sessions)), n_views)
30
+ view_session_ids = [session_ids[i] for i in view_sessions]
31
+ view_ts = [1_700_000_000 + rng.randrange(30 * 86_400) for _ in range(n_views)]
32
+
33
+ conn = duckdb.connect(":memory:")
34
+ conn.execute(
35
+ "CREATE TABLE pageviews("
36
+ " id BIGINT PRIMARY KEY, session_id VARCHAR, path VARCHAR, ts_epoch_s BIGINT);"
37
+ )
38
+ conn.executemany(
39
+ "INSERT INTO pageviews VALUES (?, ?, ?, ?)",
40
+ [
41
+ (i, sid, p, ts)
42
+ for i, (sid, p, ts) in enumerate(
43
+ zip(view_session_ids, paths, view_ts, strict=False), start=1
44
+ )
45
+ ],
46
+ )
47
+
48
+ baseline_sql = (
49
+ "SELECT DISTINCT session_id, path, COUNT(*) AS hits "
50
+ "FROM pageviews GROUP BY session_id, path "
51
+ "ORDER BY session_id, path"
52
+ )
53
+ gt_sql_predrift = (
54
+ "SELECT session_id, path, COUNT(*) AS hits "
55
+ "FROM pageviews GROUP BY session_id, path "
56
+ "ORDER BY session_id, path"
57
+ )
58
+
59
+ synopsis = (
60
+ "pageviews(id PK, session_id, path, ts_epoch_s). "
61
+ "Baseline applies redundant DISTINCT on a GROUP BY whose keys are already unique."
62
+ )
63
+ return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset()
64
+
65
+
66
+ SPEC = ScenarioSpec(
67
+ scenario_id="04_distinct_groupby",
68
+ family="saas_logs",
69
+ tags=frozenset({"distinct", "redundant_distinct", "group_by", "saas_logs"}),
70
+ drift_config=None,
71
+ builder=_build,
72
+ base_scale=500,
73
+ )
scenarios/scenario_05_nested_subquery.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scenario 05 — deeply nested subquery anti-pattern.
2
+
3
+ Baseline stacks two layers of `IN (SELECT ...)` where a single join + GROUP
4
+ BY expresses the same intent. The rewrite flattens the nesting.
5
+
6
+ Schema: articles(id, author_id, published_at), comments(id, article_id, author_id),
7
+ authors(id, display_name, is_verified).
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import duckdb
13
+
14
+ from ._fixtures import categorical_choices, seeded_rng, unique_names, zipfian_choices
15
+ from .base import BuilderResult, ScenarioSpec
16
+
17
+
18
+ def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
19
+ rng = seeded_rng(spec.scenario_id, seed, scale)
20
+ n_authors = max(50, scale // 8)
21
+ n_articles = scale
22
+ n_comments = scale * 4
23
+
24
+ author_ids = list(range(1, n_authors + 1))
25
+ display_names = unique_names(rng, n_authors, prefix="author")
26
+ verified = rng.choices([True, False], weights=[0.2, 0.8], k=n_authors)
27
+
28
+ article_ids = list(range(1, n_articles + 1))
29
+ article_authors = zipfian_choices(rng, author_ids, n_articles)
30
+ article_published = [1_700_000_000 + rng.randrange(120 * 86_400) for _ in range(n_articles)]
31
+ article_statuses = categorical_choices(
32
+ rng, ["draft", "published", "archived"], n_articles, weights=[0.2, 0.7, 0.1]
33
+ )
34
+
35
+ comment_ids = list(range(1, n_comments + 1))
36
+ comment_articles = rng.choices(article_ids, k=n_comments)
37
+ comment_authors = zipfian_choices(rng, author_ids, n_comments)
38
+
39
+ conn = duckdb.connect(":memory:")
40
+ conn.execute(
41
+ "CREATE TABLE authors( id BIGINT PRIMARY KEY, display_name VARCHAR, is_verified BOOLEAN);"
42
+ )
43
+ conn.execute(
44
+ "CREATE TABLE articles("
45
+ " id BIGINT PRIMARY KEY, author_id BIGINT, published_at_epoch_s BIGINT,"
46
+ " status VARCHAR);"
47
+ )
48
+ conn.execute(
49
+ "CREATE TABLE comments( id BIGINT PRIMARY KEY, article_id BIGINT, author_id BIGINT);"
50
+ )
51
+ conn.executemany(
52
+ "INSERT INTO authors VALUES (?, ?, ?)",
53
+ list(zip(author_ids, display_names, verified, strict=False)),
54
+ )
55
+ conn.executemany(
56
+ "INSERT INTO articles VALUES (?, ?, ?, ?)",
57
+ list(zip(article_ids, article_authors, article_published, article_statuses, strict=False)),
58
+ )
59
+ conn.executemany(
60
+ "INSERT INTO comments VALUES (?, ?, ?)",
61
+ list(zip(comment_ids, comment_articles, comment_authors, strict=False)),
62
+ )
63
+
64
+ baseline_sql = (
65
+ "SELECT display_name "
66
+ "FROM authors "
67
+ "WHERE id IN ("
68
+ " SELECT author_id FROM comments "
69
+ " WHERE article_id IN ("
70
+ " SELECT id FROM articles WHERE status = 'published'"
71
+ " )"
72
+ ") "
73
+ "ORDER BY display_name"
74
+ )
75
+ gt_sql_predrift = (
76
+ "SELECT DISTINCT a.display_name "
77
+ "FROM authors a "
78
+ "JOIN comments c ON c.author_id = a.id "
79
+ "JOIN articles ar ON ar.id = c.article_id "
80
+ "WHERE ar.status = 'published' "
81
+ "ORDER BY a.display_name"
82
+ )
83
+
84
+ synopsis = (
85
+ "authors(id PK, display_name, is_verified); "
86
+ "articles(id PK, author_id→authors.id, published_at_epoch_s, status); "
87
+ "comments(id PK, article_id→articles.id, author_id→authors.id). "
88
+ "Baseline chains two IN-subqueries where one JOIN + DISTINCT suffices."
89
+ )
90
+ return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset()
91
+
92
+
93
+ SPEC = ScenarioSpec(
94
+ scenario_id="05_nested_subquery",
95
+ family="cms",
96
+ tags=frozenset({"nested_subquery", "in_subquery", "cms"}),
97
+ drift_config=None,
98
+ builder=_build,
99
+ base_scale=1_500,
100
+ )
scenarios/scenario_06_having_as_where.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scenario 06 — HAVING used as WHERE.
2
+
3
+ Baseline filters on a grouping key inside HAVING, which forces the engine
4
+ to group first and filter after. The rewrite moves the non-aggregate
5
+ predicate into WHERE.
6
+
7
+ Schema: orders(id, tenant_id, user_id, amount_cents, status, created_at_epoch_s).
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import duckdb
13
+
14
+ from ._fixtures import categorical_choices, lognormal_amounts, seeded_rng, zipfian_choices
15
+ from .base import BuilderResult, ScenarioSpec
16
+
17
+
18
+ def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
19
+ rng = seeded_rng(spec.scenario_id, seed, scale)
20
+ n_tenants = max(40, scale // 40)
21
+ n_users = max(200, scale // 4)
22
+ n_orders = scale * 10
23
+
24
+ tenant_ids = list(range(1, n_tenants + 1))
25
+ user_ids = list(range(1, n_users + 1))
26
+ user_tenants = rng.choices(tenant_ids, k=n_users)
27
+
28
+ order_user = zipfian_choices(rng, user_ids, n_orders)
29
+ order_tenant = [user_tenants[u - 1] for u in order_user]
30
+ order_amount = [int(x * 100) for x in lognormal_amounts(rng, n_orders, mu=3.5, sigma=0.8)]
31
+ statuses = categorical_choices(
32
+ rng,
33
+ ["placed", "fulfilled", "refunded", "cancelled"],
34
+ n_orders,
35
+ weights=[0.55, 0.3, 0.1, 0.05],
36
+ )
37
+ created = [1_700_000_000 + rng.randrange(60 * 86_400) for _ in range(n_orders)]
38
+
39
+ conn = duckdb.connect(":memory:")
40
+ conn.execute(
41
+ "CREATE TABLE orders("
42
+ " id BIGINT PRIMARY KEY, tenant_id BIGINT, user_id BIGINT,"
43
+ " amount_cents BIGINT, status VARCHAR, created_at_epoch_s BIGINT);"
44
+ )
45
+ conn.executemany(
46
+ "INSERT INTO orders VALUES (?, ?, ?, ?, ?, ?)",
47
+ [
48
+ (i, t, u, a, s, ts)
49
+ for i, (t, u, a, s, ts) in enumerate(
50
+ zip(order_tenant, order_user, order_amount, statuses, created, strict=False),
51
+ start=1,
52
+ )
53
+ ],
54
+ )
55
+
56
+ # Anti-pattern: filter on a row-level column (status) that IS in GROUP BY
57
+ # inside HAVING, forcing the engine to build groups for every status value
58
+ # before discarding most of them. The rewrite moves the filter into WHERE
59
+ # so aggregation only runs over rows we actually care about.
60
+ baseline_sql = (
61
+ "SELECT tenant_id, status, SUM(amount_cents) AS total_cents "
62
+ "FROM orders "
63
+ "GROUP BY tenant_id, status "
64
+ "HAVING status = 'fulfilled' "
65
+ " AND SUM(amount_cents) >= 100000 "
66
+ "ORDER BY tenant_id"
67
+ )
68
+ gt_sql_predrift = (
69
+ "SELECT tenant_id, status, SUM(amount_cents) AS total_cents "
70
+ "FROM orders "
71
+ "WHERE status = 'fulfilled' "
72
+ "GROUP BY tenant_id, status "
73
+ "HAVING SUM(amount_cents) >= 100000 "
74
+ "ORDER BY tenant_id"
75
+ )
76
+
77
+ synopsis = (
78
+ "orders(id PK, tenant_id, user_id, amount_cents, status, created_at_epoch_s). "
79
+ "Baseline filters `status` inside HAVING, forcing aggregation over "
80
+ "every status group before discarding all but 'fulfilled'."
81
+ )
82
+ return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset()
83
+
84
+
85
+ SPEC = ScenarioSpec(
86
+ scenario_id="06_having_as_where",
87
+ family="ecommerce",
88
+ tags=frozenset({"having_as_where", "aggregate_filter", "ecommerce"}),
89
+ drift_config=None,
90
+ builder=_build,
91
+ base_scale=1_500,
92
+ )
scenarios/scenario_07_drift_column_rename.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scenario 07 — drift: column rename (orders.user_id → orders.account_id).
2
+
3
+ Baseline groups orders by `user_id`. When the drift fires, the column is
4
+ renamed; the agent must rewrite its query against `account_id`. Row
5
+ semantics are unchanged — the post-drift hash equals the pre-drift hash
6
+ because the only thing that changed is the column label.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import duckdb
12
+
13
+ from ._fixtures import lognormal_amounts, seeded_rng, zipfian_choices
14
+ from .base import BuilderResult, DriftConfig, ScenarioSpec
15
+
16
+
17
+ def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
18
+ rng = seeded_rng(spec.scenario_id, seed, scale)
19
+ n_users = max(200, scale // 2)
20
+ n_orders = scale * 4
21
+
22
+ user_ids = list(range(1, n_users + 1))
23
+ order_users = zipfian_choices(rng, user_ids, n_orders)
24
+ amounts = lognormal_amounts(rng, n_orders, mu=3.0, sigma=0.8)
25
+
26
+ conn = duckdb.connect(":memory:")
27
+ conn.execute("CREATE TABLE users(id BIGINT PRIMARY KEY, signup_month INTEGER);")
28
+ conn.execute("CREATE TABLE orders(id BIGINT PRIMARY KEY, user_id BIGINT, amount DOUBLE);")
29
+ conn.executemany(
30
+ "INSERT INTO users VALUES (?, ?)",
31
+ [(uid, (uid % 12) + 1) for uid in user_ids],
32
+ )
33
+ conn.executemany(
34
+ "INSERT INTO orders VALUES (?, ?, ?)",
35
+ [(i, uid, amt) for i, (uid, amt) in enumerate(zip(order_users, amounts, strict=False), 1)],
36
+ )
37
+
38
+ baseline_sql = (
39
+ "SELECT user_id, COUNT(*) AS n_orders, ROUND(SUM(amount), 2) AS total "
40
+ "FROM orders GROUP BY user_id ORDER BY user_id"
41
+ )
42
+ gt_sql_predrift = baseline_sql # static part — baseline IS correct pre-drift
43
+ gt_sql_postdrift = (
44
+ "SELECT account_id, COUNT(*) AS n_orders, ROUND(SUM(amount), 2) AS total "
45
+ "FROM orders GROUP BY account_id ORDER BY account_id"
46
+ )
47
+
48
+ synopsis = (
49
+ "users(id PK, signup_month); orders(id PK, user_id→users.id, amount). "
50
+ "Under drift, orders.user_id is renamed to orders.account_id."
51
+ )
52
+ return (
53
+ conn,
54
+ baseline_sql,
55
+ gt_sql_predrift,
56
+ gt_sql_postdrift,
57
+ synopsis,
58
+ frozenset({"account_id"}),
59
+ frozenset({"user_id"}),
60
+ )
61
+
62
+
63
+ SPEC = ScenarioSpec(
64
+ scenario_id="07_drift_column_rename",
65
+ family="ecommerce",
66
+ tags=frozenset({"drift", "column_rename", "ecommerce"}),
67
+ drift_config=DriftConfig(
68
+ kind="column_rename",
69
+ payload={"table": "orders", "old": "user_id", "new": "account_id"},
70
+ ),
71
+ builder=_build,
72
+ base_scale=2_000,
73
+ )
scenarios/scenario_08_drift_date_format.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scenario 08 — drift: date format (events.ts iso_string → epoch_ms).
2
+
3
+ Baseline filters events inside a specific UTC day via ISO string
4
+ comparisons. When the drift fires, the `ts` column becomes a BIGINT of
5
+ epoch-ms; the agent must rewrite comparisons against the numeric value.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from datetime import UTC, datetime, timedelta
11
+
12
+ import duckdb
13
+
14
+ from ._fixtures import categorical_choices, seeded_rng
15
+ from .base import BuilderResult, DriftConfig, ScenarioSpec
16
+
17
+ # Fixed anchor day — comparisons are deterministic across seeds.
18
+ _ANCHOR = datetime(2026, 4, 21, tzinfo=UTC)
19
+ _ANCHOR_NEXT = _ANCHOR + timedelta(days=1)
20
+ _ANCHOR_ISO = _ANCHOR.isoformat().replace("+00:00", "Z")
21
+ _ANCHOR_NEXT_ISO = _ANCHOR_NEXT.isoformat().replace("+00:00", "Z")
22
+ _ANCHOR_MS = int(_ANCHOR.timestamp() * 1000)
23
+ _ANCHOR_NEXT_MS = int(_ANCHOR_NEXT.timestamp() * 1000)
24
+
25
+
26
+ def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
27
+ rng = seeded_rng(spec.scenario_id, seed, scale)
28
+ n_events = scale * 8
29
+
30
+ window_start = _ANCHOR - timedelta(days=3)
31
+ window_span_s = 7 * 86_400
32
+ event_dts = [
33
+ (window_start + timedelta(seconds=rng.randrange(window_span_s))) for _ in range(n_events)
34
+ ]
35
+ event_iso = [dt.isoformat().replace("+00:00", "Z") for dt in event_dts]
36
+ kinds = categorical_choices(
37
+ rng, ["login", "action", "error"], n_events, weights=[0.6, 0.35, 0.05]
38
+ )
39
+
40
+ conn = duckdb.connect(":memory:")
41
+ conn.execute("CREATE TABLE events( id BIGINT PRIMARY KEY, ts VARCHAR, kind VARCHAR);")
42
+ conn.executemany(
43
+ "INSERT INTO events VALUES (?, ?, ?)",
44
+ [(i, iso, k) for i, (iso, k) in enumerate(zip(event_iso, kinds, strict=False), 1)],
45
+ )
46
+
47
+ baseline_sql = (
48
+ f"SELECT kind, COUNT(*) AS n FROM events "
49
+ f"WHERE ts >= '{_ANCHOR_ISO}' AND ts < '{_ANCHOR_NEXT_ISO}' "
50
+ f"GROUP BY kind ORDER BY kind"
51
+ )
52
+ gt_sql_predrift = baseline_sql
53
+ gt_sql_postdrift = (
54
+ f"SELECT kind, COUNT(*) AS n FROM events "
55
+ f"WHERE ts >= {_ANCHOR_MS} AND ts < {_ANCHOR_NEXT_MS} "
56
+ f"GROUP BY kind ORDER BY kind"
57
+ )
58
+
59
+ synopsis = (
60
+ "events(id PK, ts VARCHAR(ISO-8601 UTC), kind). Under drift, `ts` "
61
+ f"becomes BIGINT epoch-ms. Filter window is {_ANCHOR_ISO} – "
62
+ f"{_ANCHOR_NEXT_ISO} (i.e. epoch-ms in "
63
+ f"[{_ANCHOR_MS}, {_ANCHOR_NEXT_MS}))."
64
+ )
65
+ # Date-format drift keeps the ``ts`` identifier; what changes is the
66
+ # literal shape (ISO string → epoch-ms integer). The rubric
67
+ # therefore can't distinguish "adapted" from "not adapted" on
68
+ # identifiers alone, so we expose the ISO anchor strings as the
69
+ # pre-drift distinctive set and leave postdrift empty — the rubric
70
+ # treats absence-of-predrift-markers as adaptation whenever
71
+ # ``postdrift_identifiers`` is empty (DriftAdapt rubric case).
72
+ return (
73
+ conn,
74
+ baseline_sql,
75
+ gt_sql_predrift,
76
+ gt_sql_postdrift,
77
+ synopsis,
78
+ frozenset(),
79
+ frozenset({_ANCHOR_ISO, _ANCHOR_NEXT_ISO}),
80
+ )
81
+
82
+
83
+ SPEC = ScenarioSpec(
84
+ scenario_id="08_drift_date_format",
85
+ family="events",
86
+ tags=frozenset({"drift", "date_format", "iso_to_epoch", "events"}),
87
+ drift_config=DriftConfig(
88
+ kind="date_format",
89
+ payload={"table": "events", "col": "ts", "from": "iso_string", "to": "epoch_ms"},
90
+ ),
91
+ builder=_build,
92
+ base_scale=500,
93
+ )
scenarios/scenario_09_drift_enum_rule.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scenario 09 — drift: enum split ('active' → 'ACTIVE' / 'ACTIVE_V2').
2
+
3
+ Baseline counts tenants whose status is 'active'. After drift, every
4
+ previously-'active' row has been relabelled to either 'ACTIVE' or
5
+ 'ACTIVE_V2' (deterministic round-robin). The agent must filter on the
6
+ union of the new values to recover the business-equivalent count.
7
+
8
+ Note: unlike 07 and 10, the post-drift data changed, so the post-drift
9
+ ground-truth hash is computed against the post-drift rows — the agent's
10
+ result set now reflects the new status values.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import duckdb
16
+
17
+ from ._fixtures import categorical_choices, seeded_rng
18
+ from .base import BuilderResult, DriftConfig, ScenarioSpec
19
+
20
+
21
+ def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
22
+ rng = seeded_rng(spec.scenario_id, seed, scale)
23
+ n_tenants = max(400, scale)
24
+
25
+ statuses = categorical_choices(
26
+ rng,
27
+ ["active", "trial", "suspended", "churned"],
28
+ n_tenants,
29
+ weights=[0.55, 0.2, 0.15, 0.1],
30
+ )
31
+ tiers = categorical_choices(
32
+ rng, ["free", "pro", "business"], n_tenants, weights=[0.6, 0.3, 0.1]
33
+ )
34
+
35
+ conn = duckdb.connect(":memory:")
36
+ conn.execute("CREATE TABLE tenants( id BIGINT PRIMARY KEY, status VARCHAR, tier VARCHAR);")
37
+ conn.executemany(
38
+ "INSERT INTO tenants VALUES (?, ?, ?)",
39
+ [(i, s, t) for i, (s, t) in enumerate(zip(statuses, tiers, strict=False), 1)],
40
+ )
41
+
42
+ # Pre-drift: all 'active' tenants; post-drift: union of the new labels.
43
+ baseline_sql = (
44
+ "SELECT tier, COUNT(*) AS n "
45
+ "FROM tenants WHERE status = 'active' "
46
+ "GROUP BY tier ORDER BY tier"
47
+ )
48
+ gt_sql_predrift = baseline_sql
49
+ gt_sql_postdrift = (
50
+ "SELECT tier, COUNT(*) AS n "
51
+ "FROM tenants WHERE status IN ('ACTIVE', 'ACTIVE_V2') "
52
+ "GROUP BY tier ORDER BY tier"
53
+ )
54
+
55
+ synopsis = (
56
+ "tenants(id PK, status, tier). Under drift, status='active' is split "
57
+ "into 'ACTIVE' and 'ACTIVE_V2'; 'trial'/'suspended'/'churned' are unchanged."
58
+ )
59
+ return (
60
+ conn,
61
+ baseline_sql,
62
+ gt_sql_predrift,
63
+ gt_sql_postdrift,
64
+ synopsis,
65
+ frozenset({"ACTIVE", "ACTIVE_V2"}),
66
+ frozenset({"active"}),
67
+ )
68
+
69
+
70
+ SPEC = ScenarioSpec(
71
+ scenario_id="09_drift_enum_rule",
72
+ family="multitenant",
73
+ tags=frozenset({"drift", "enum_rule", "business_rule", "multitenant"}),
74
+ drift_config=DriftConfig(
75
+ kind="enum_rule",
76
+ payload={
77
+ "table": "tenants",
78
+ "col": "status",
79
+ "old_value": "active",
80
+ "new_values": ["ACTIVE", "ACTIVE_V2"],
81
+ },
82
+ ),
83
+ builder=_build,
84
+ base_scale=600,
85
+ )
scenarios/scenario_10_drift_field_deprecation.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scenario 10 — drift: inline string col replaced by FK lookup.
2
+
3
+ Baseline groups posts by the inline `author_name` string. Under drift, a
4
+ `users(id, full_name)` lookup is created, `posts.author_name` is dropped,
5
+ and `posts.users_id` is added with a backfilled FK. The correct rewrite
6
+ joins through `users` and groups by `full_name`.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import duckdb
12
+
13
+ from ._fixtures import seeded_rng, unique_names, zipfian_choices
14
+ from .base import BuilderResult, DriftConfig, ScenarioSpec
15
+
16
+
17
+ def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult:
18
+ rng = seeded_rng(spec.scenario_id, seed, scale)
19
+ n_authors = max(40, scale // 10)
20
+ n_posts = scale * 4
21
+
22
+ author_names = unique_names(rng, n_authors, prefix="author")
23
+ post_author_idx = zipfian_choices(rng, list(range(n_authors)), n_posts)
24
+ post_author_names = [author_names[i] for i in post_author_idx]
25
+
26
+ conn = duckdb.connect(":memory:")
27
+ conn.execute("CREATE TABLE posts( id BIGINT PRIMARY KEY, author_name VARCHAR, body VARCHAR);")
28
+ conn.executemany(
29
+ "INSERT INTO posts VALUES (?, ?, ?)",
30
+ [(i, n, f"body {i}") for i, n in enumerate(post_author_names, 1)],
31
+ )
32
+
33
+ baseline_sql = (
34
+ "SELECT author_name, COUNT(*) AS n_posts "
35
+ "FROM posts GROUP BY author_name "
36
+ "ORDER BY author_name"
37
+ )
38
+ gt_sql_predrift = baseline_sql
39
+ gt_sql_postdrift = (
40
+ "SELECT u.full_name AS author_name, COUNT(*) AS n_posts "
41
+ "FROM posts p JOIN users u ON u.id = p.users_id "
42
+ "GROUP BY u.full_name ORDER BY u.full_name"
43
+ )
44
+
45
+ synopsis = (
46
+ "posts(id PK, author_name, body). Under drift, posts.author_name is "
47
+ "deprecated; a new users(id PK, full_name) table is created and "
48
+ "posts gains a users_id FK. Rewrites must JOIN through users."
49
+ )
50
+ return (
51
+ conn,
52
+ baseline_sql,
53
+ gt_sql_predrift,
54
+ gt_sql_postdrift,
55
+ synopsis,
56
+ frozenset({"users", "users_id", "full_name"}),
57
+ frozenset({"author_name"}),
58
+ )
59
+
60
+
61
+ SPEC = ScenarioSpec(
62
+ scenario_id="10_drift_field_deprecation",
63
+ family="cms",
64
+ tags=frozenset({"drift", "field_deprecation", "fk_backfill", "cms"}),
65
+ drift_config=DriftConfig(
66
+ kind="field_deprecation",
67
+ payload={
68
+ "orig": ("posts", "author_name"),
69
+ "lookup": ("users", "id", "full_name"),
70
+ },
71
+ ),
72
+ builder=_build,
73
+ base_scale=1_500,
74
+ )
server/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """SQLDrift server package."""
2
+
3
+ from server.sql_drift_env_environment import SqlDriftEnvironment
4
+
5
+ __all__ = ["SqlDriftEnvironment"]
server/app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI factory for the SQLDrift environment.
2
+
3
+ ``create_app()`` returns a fully-wired FastAPI app exposing the
4
+ stateless HTTP routes (``/health``, ``/schema``, ``/reset``, ``/step``)
5
+ and the stateful ``/ws`` WebSocket session. Stateful multi-step
6
+ episodes must go through ``/ws``; each HTTP ``/step`` spawns a
7
+ fresh env instance that is ``close()``-d in ``finally`` (one env per request).
8
+
9
+ ``main()`` runs the server with Uvicorn — exported as the
10
+ ``[project.scripts] sql-drift-server`` entry point.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from typing import Any
16
+ from uuid import uuid4
17
+
18
+ from openenv.core.env_server.http_server import create_app as _openenv_create_app
19
+
20
+ from models import SqlDriftAction, SqlDriftObservation
21
+ from skill_library import DEFAULT_STORE_DIR, Store, cleanup_stale_session_dirs
22
+
23
+ from . import settings
24
+ from .sql_drift_env_environment import SqlDriftEnvironment
25
+
26
+ ENV_NAME = "sql_drift_env"
27
+ DEFAULT_MAX_CONCURRENT_ENVS = settings.MAX_CONCURRENT_ENVS
28
+ _SESSION_STORE_ROOT = DEFAULT_STORE_DIR / "sessions"
29
+
30
+ # Purge stale session directories left by previous server runs before
31
+ # accepting any traffic. Failures are non-fatal.
32
+ _startup_removed = cleanup_stale_session_dirs(
33
+ _SESSION_STORE_ROOT, settings.SKILL_STORE_SESSION_TTL_HOURS
34
+ )
35
+ if _startup_removed:
36
+ import logging as _logging
37
+
38
+ _logging.getLogger("sql_drift_env.app.server.app").info(
39
+ "startup: removed %d stale session skill-store dirs from %s",
40
+ _startup_removed,
41
+ _SESSION_STORE_ROOT,
42
+ )
43
+
44
+
45
+ def _create_server_environment() -> SqlDriftEnvironment:
46
+ """Build one server-managed env with its own on-disk skill library.
47
+
48
+ ``cleanup_on_close=True`` ensures the session directory is deleted when
49
+ the WebSocket session ends, preventing unbounded on-disk session growth.
50
+ """
51
+ session_dir = _SESSION_STORE_ROOT / uuid4().hex
52
+ return SqlDriftEnvironment(
53
+ skill_store=Store(directory=session_dir),
54
+ cleanup_on_close=True,
55
+ )
56
+
57
+
58
+ def create_app(max_concurrent_envs: int | None = None) -> Any:
59
+ """Build the FastAPI app bound to a fresh-env factory per session."""
60
+ if max_concurrent_envs is None:
61
+ max_concurrent_envs = DEFAULT_MAX_CONCURRENT_ENVS
62
+ return _openenv_create_app(
63
+ env=_create_server_environment,
64
+ action_cls=SqlDriftAction,
65
+ observation_cls=SqlDriftObservation,
66
+ env_name=ENV_NAME,
67
+ max_concurrent_envs=max_concurrent_envs,
68
+ )
69
+
70
+
71
+ def main(host: str = settings.SERVER_HOST, port: int = settings.SERVER_PORT) -> None:
72
+ """Uvicorn entry point — matches the [project.scripts] wiring."""
73
+ import uvicorn
74
+
75
+ uvicorn.run(create_app(), host=host, port=port)
76
+
77
+
78
+ # Module-level app instance for uvicorn's ``module:attr`` syntax
79
+ # (``uvicorn server.app:app``) and the ``openenv.yaml`` ``app:`` field.
80
+ # Built at import time; safe because the OpenEnv factory only stores the
81
+ # environment factory and instantiates per request / session.
82
+ app = create_app()
83
+
84
+
85
+ __all__ = ["ENV_NAME", "app", "create_app", "main"]
86
+
87
+
88
+ if __name__ == "__main__":
89
+ main()
server/requirements.txt ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv export --format requirements-txt --no-dev --no-hashes --no-emit-project --no-annotate -o server/requirements.txt
3
+ aiofile==3.9.0
4
+ annotated-doc==0.0.4
5
+ annotated-types==0.7.0
6
+ anyio==4.13.0
7
+ attrs==26.1.0
8
+ audioop-lts==0.2.2 ; python_full_version >= '3.13'
9
+ authlib==1.7.0
10
+ beartype==0.22.9
11
+ brotli==1.2.0
12
+ cachetools==7.0.6
13
+ caio==0.9.25
14
+ certifi==2026.2.25
15
+ cffi==2.0.0 ; platform_python_implementation != 'PyPy'
16
+ charset-normalizer==3.4.7
17
+ click==8.3.2
18
+ colorama==0.4.6 ; sys_platform == 'win32'
19
+ cryptography==46.0.7
20
+ cyclopts==4.10.2
21
+ distro==1.9.0
22
+ dnspython==2.8.0
23
+ docstring-parser==0.18.0
24
+ docutils==0.22.4
25
+ duckdb==1.5.2
26
+ email-validator==2.3.0
27
+ exceptiongroup==1.3.1
28
+ fastapi==0.136.0
29
+ fastmcp==3.2.4
30
+ filelock==3.29.0
31
+ fsspec==2026.2.0
32
+ gradio==6.13.0
33
+ gradio-client==2.5.0
34
+ griffelib==2.0.2
35
+ groovy==0.1.2
36
+ h11==0.16.0
37
+ hf-gradio==0.4.1
38
+ hf-xet==1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
39
+ httpcore==1.0.9
40
+ httpx==0.28.1
41
+ httpx-sse==0.4.3
42
+ huggingface-hub==1.11.0
43
+ idna==3.11
44
+ importlib-metadata==8.7.1
45
+ jaraco-classes==3.4.0
46
+ jaraco-context==6.1.2
47
+ jaraco-functools==4.4.0
48
+ jeepney==0.9.0 ; sys_platform == 'linux'
49
+ jinja2==3.1.6
50
+ jiter==0.14.0
51
+ joserfc==1.6.4
52
+ jsonref==1.1.0
53
+ jsonschema==4.26.0
54
+ jsonschema-path==0.4.5
55
+ jsonschema-specifications==2025.9.1
56
+ keyring==25.7.0
57
+ markdown-it-py==4.0.0
58
+ markupsafe==3.0.3
59
+ mcp==1.27.0
60
+ mdurl==0.1.2
61
+ more-itertools==11.0.2
62
+ numpy==2.4.4
63
+ openai==2.32.0
64
+ openapi-pydantic==0.5.1
65
+ openenv-core==0.2.3
66
+ opentelemetry-api==1.41.0
67
+ orjson==3.11.8
68
+ packaging==26.1
69
+ pandas==3.0.2
70
+ pathable==0.5.0
71
+ pillow==12.2.0
72
+ platformdirs==4.9.6
73
+ py-key-value-aio==0.4.4
74
+ pycparser==3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy'
75
+ pydantic==2.13.3
76
+ pydantic-core==2.46.3
77
+ pydantic-settings==2.14.0
78
+ pydub==0.25.1
79
+ pygments==2.20.0
80
+ pyjwt==2.12.1
81
+ pyperclip==1.11.0
82
+ python-dateutil==2.9.0.post0
83
+ python-dotenv==1.2.2
84
+ python-multipart==0.0.26
85
+ pytz==2026.1.post1
86
+ pywin32==311 ; sys_platform == 'win32'
87
+ pywin32-ctypes==0.2.3 ; sys_platform == 'win32'
88
+ pyyaml==6.0.3
89
+ referencing==0.37.0
90
+ requests==2.33.1
91
+ rich==15.0.0
92
+ rich-rst==1.3.2
93
+ rpds-py==0.30.0
94
+ safehttpx==0.1.7
95
+ secretstorage==3.5.0 ; sys_platform == 'linux'
96
+ semantic-version==2.10.0
97
+ shellingham==1.5.4
98
+ six==1.17.0
99
+ sniffio==1.3.1
100
+ sqlglot==30.6.0
101
+ sse-starlette==3.3.4
102
+ starlette==1.0.0
103
+ tomli==2.4.1
104
+ tomli-w==1.2.0
105
+ tomlkit==0.14.0
106
+ tqdm==4.67.3
107
+ typer==0.24.1
108
+ typing-extensions==4.15.0
109
+ typing-inspection==0.4.2
110
+ tzdata==2026.1 ; sys_platform == 'emscripten' or sys_platform == 'win32'
111
+ uncalled-for==0.3.1
112
+ urllib3==2.6.3
113
+ uvicorn==0.44.0
114
+ watchfiles==1.1.1
115
+ websockets==16.0
116
+ zipp==3.23.1
server/settings.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Server/runtime settings loaded from the repo-local `.env`."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from utilities.env_loader import env_float, env_int, env_str
6
+
7
+ SERVER_HOST = env_str("SQL_DRIFT_SERVER_HOST", "0.0.0.0")
8
+ SERVER_PORT = env_int("SQL_DRIFT_SERVER_PORT", 8000, min_value=1, max_value=65535)
9
+ MAX_CONCURRENT_ENVS = env_int("SQL_DRIFT_MAX_CONCURRENT_ENVS", 4, min_value=1)
10
+
11
+ DEFAULT_STEP_BUDGET = env_int("SQL_DRIFT_DEFAULT_STEP_BUDGET", 25, min_value=1)
12
+ MAX_RESULT_ROWS = env_int("SQL_DRIFT_MAX_RESULT_ROWS", 1_000, min_value=1)
13
+ QUERY_TIMEOUT_S = env_float("SQL_DRIFT_QUERY_TIMEOUT_S", 2.0, min_value=0.001)
14
+
15
+ # Session skill-store directories older than this many hours are removed at
16
+ # startup and when the owning environment is closed. Set to 0 to disable
17
+ # TTL-based cleanup (directories will still be removed on close when
18
+ # cleanup_on_close=True is set for a server-managed environment).
19
+ SKILL_STORE_SESSION_TTL_HOURS = env_float(
20
+ "SQL_DRIFT_SKILL_STORE_SESSION_TTL_HOURS", 24.0, min_value=0.0
21
+ )
22
+
23
+ __all__ = [
24
+ "DEFAULT_STEP_BUDGET",
25
+ "MAX_CONCURRENT_ENVS",
26
+ "MAX_RESULT_ROWS",
27
+ "QUERY_TIMEOUT_S",
28
+ "SERVER_HOST",
29
+ "SERVER_PORT",
30
+ "SKILL_STORE_SESSION_TTL_HOURS",
31
+ ]
server/sql_drift_env_environment.py ADDED
@@ -0,0 +1,899 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenEnv ``Environment`` implementation for SQLDrift.
2
+
3
+ Responsibilities:
4
+
5
+ * Own the private :class:`engine.runtime.RuntimeEpisodeState` and the
6
+ composite :class:`engine.reward.SqlDriftRubric` for the current episode.
7
+ * Dispatch each of the eight tool-call payloads to a dedicated
8
+ ``_handle_<tool>`` method that returns a typed
9
+ :class:`models.ToolResult` (or :class:`models.ToolError`).
10
+ * Fire drift on a schedule blended with a cooldown: ``max(scheduled,
11
+ first_run_query_step + cooldown)`` before the agent acts on the step
12
+ where drift applies, then recompute the post-drift ground truth hash.
13
+ * Publish public observations (:class:`models.SqlDriftObservation`) and a
14
+ strictly sanitised public state snapshot (:class:`models.SqlDriftState`).
15
+
16
+ Privacy: ``self._runtime`` holds the DuckDB handle, ground-truth hashes,
17
+ baseline runtime, and seed. They stay inside this class; the rubric reads
18
+ them via a closure, and ``env.state`` exposes only a fixed whitelist of fields.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import contextlib
24
+ import math
25
+ import re
26
+ import secrets
27
+ from random import Random
28
+ from typing import TYPE_CHECKING, Any, Literal
29
+
30
+ import duckdb
31
+ import sqlglot
32
+ from openenv.core.env_server.interfaces import Environment
33
+ from pydantic import BaseModel, ConfigDict, Field
34
+
35
+ from actors import dba_oracle
36
+ from actors.engineering_manager import author_changelog
37
+ from engine.drift import apply_drift
38
+ from engine.profiler import (
39
+ QueryWatchdogEscalationError,
40
+ execute_hash_timed,
41
+ execute_once_timed,
42
+ execute_once_with_columns,
43
+ )
44
+ from engine.reward import (
45
+ SPEEDUP_CAP_FOR_INFTY,
46
+ STEP_REBATE_DESCRIBE_TABLE,
47
+ STEP_REBATE_EXPLAIN_QUERY,
48
+ STEP_REBATE_LIST_TABLES,
49
+ STEP_REBATE_READ_CHANGELOG,
50
+ STEP_REBATE_RUN_QUERY,
51
+ STEP_REBATE_SAMPLE_ROWS,
52
+ SqlDriftRubric,
53
+ canonicalize_sql,
54
+ effective_speedup,
55
+ )
56
+ from engine.runtime import RuntimeEpisodeState
57
+ from engine.verifier import canonical_row_hash
58
+ from models import (
59
+ REWARD_COMPONENT_KEYS,
60
+ ConsultDBAPayload,
61
+ ConsultDBAResult,
62
+ DescribeTablePayload,
63
+ DescribeTableResult,
64
+ EpisodePhase,
65
+ ExplainQueryPayload,
66
+ ExplainQueryResult,
67
+ ListTablesPayload,
68
+ ListTablesResult,
69
+ ReadChangelogPayload,
70
+ ReadChangelogResult,
71
+ RunQueryPayload,
72
+ RunQueryResult,
73
+ SampleRowsPayload,
74
+ SampleRowsResult,
75
+ SqlDriftAction,
76
+ SqlDriftObservation,
77
+ SqlDriftState,
78
+ SubmitRewritePayload,
79
+ SubmitRewriteResult,
80
+ ToolError,
81
+ ToolErrorCode,
82
+ ToolResult,
83
+ )
84
+ from scenarios import REGISTRY, get_spec
85
+ from skill_library import PlaybookEntry, Store, load_all, retrieve
86
+ from utilities.logger import get_module_logger, log_env_reset, log_env_step, log_interaction
87
+
88
+ from . import settings
89
+
90
+ if TYPE_CHECKING:
91
+ from scenarios.base import ScenarioSpec
92
+
93
+ _LOG = get_module_logger(__name__)
94
+
95
+ DEFAULT_STEP_BUDGET: int = settings.DEFAULT_STEP_BUDGET
96
+ MAX_RESULT_ROWS: int = settings.MAX_RESULT_ROWS
97
+ QUERY_TIMEOUT_S: float = settings.QUERY_TIMEOUT_S
98
+
99
+
100
+ class _ResetOptions(BaseModel):
101
+ model_config = ConfigDict(extra="ignore")
102
+
103
+ scenario_id: str | None = None
104
+ enable_dba_oracle: bool | None = None
105
+ difficulty: Literal["easy", "normal", "hard"] = "normal"
106
+ budget_steps: int = Field(default=DEFAULT_STEP_BUDGET, ge=1)
107
+
108
+
109
+ _READ_ONLY_EXPRESSION_KEYS: frozenset[str] = frozenset({"select", "with"})
110
+
111
+ # DuckDB exposes a family of table-valued functions and scalar helpers
112
+ # that read from the host filesystem or leak introspection state —
113
+ # ``read_csv``, ``read_parquet``, ``read_json``, ``read_text``,
114
+ # ``parquet_metadata``, ``duckdb_secrets``, ``glob``, etc. They are
115
+ # *technically* SELECT-shaped calls so the statement-key check alone
116
+ # admits them. We reject any function whose lowercased name starts with
117
+ # one of these prefixes or exactly matches one of the known-dangerous
118
+ # standalone names. Agent-facing SQL has no legitimate need for any of
119
+ # them — the DuckDB connection is pre-populated by the scenario builder.
120
+ _DENYLIST_PREFIXES: tuple[str, ...] = (
121
+ "read_",
122
+ "write_",
123
+ "copy_",
124
+ "duckdb_",
125
+ "pragma_",
126
+ "sniff_",
127
+ "parquet_",
128
+ "arrow_",
129
+ "json_table",
130
+ "json_each",
131
+ "sqlite_",
132
+ "load_",
133
+ "install_",
134
+ )
135
+ _DENYLIST_EXACT: frozenset[str] = frozenset(
136
+ {
137
+ "glob",
138
+ "attach",
139
+ "detach",
140
+ "checkpoint",
141
+ "force_checkpoint",
142
+ "set_secret",
143
+ "create_secret",
144
+ "drop_secret",
145
+ "enable_profiling",
146
+ "disable_profiling",
147
+ "enable_object_cache",
148
+ }
149
+ )
150
+
151
+
152
+ def _is_denylisted_function_name(name: str) -> bool:
153
+ """Return True iff ``name`` (case-insensitively) matches a sandbox-escape."""
154
+ lowered = name.lower()
155
+ if lowered in _DENYLIST_EXACT:
156
+ return True
157
+ return any(lowered.startswith(p) for p in _DENYLIST_PREFIXES)
158
+
159
+
160
+ def _function_names(node: sqlglot.exp.Func) -> list[str]:
161
+ """All plausible names to check against the denylist for one AST node.
162
+
163
+ sqlglot lowers a few DuckDB calls into dedicated expression classes
164
+ (``ReadCSV``, ``ReadParquet``, …) whose ``.name`` is actually the
165
+ first positional arg — the file path — not the function name. We
166
+ recover the function name from the class name in that case and fall
167
+ back to ``.name`` for the ``Anonymous`` form that covers everything
168
+ else. Including both lets one denylist lookup cover both lowerings.
169
+ """
170
+ cls = type(node).__name__
171
+ out: list[str] = []
172
+ # Derive a snake-case function name from the class name. We insert
173
+ # an underscore at two kinds of CamelCase boundaries:
174
+ #
175
+ # * ``aB`` — normal lower-to-upper (``ReadParquet`` → ``read_parquet``)
176
+ # * ``ABc`` — end of an acronym run (``ReadCSVAuto`` → ``read_csv_auto``)
177
+ #
178
+ # Purely-lowercase class names (``Anonymous``) produce no prefix
179
+ # match; we fall through to ``.name`` below for those.
180
+ if cls and cls[0].isupper():
181
+ snake = re.sub(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])", "_", cls).lower()
182
+ out.append(snake)
183
+ name_attr = getattr(node, "name", None)
184
+ if isinstance(name_attr, str) and name_attr:
185
+ out.append(name_attr)
186
+ return out
187
+
188
+
189
+ _VALID_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
190
+
191
+
192
+ def _resolve_timeout_s(timeout_s: float | None) -> float:
193
+ """Caller-supplied per-step timeout or the module default.
194
+
195
+ ``timeout_s`` is accepted on every OpenEnv ``step()`` (the abstract
196
+ base mandates the keyword). When the caller provides a positive
197
+ value we honour it as the wall-clock budget for any DuckDB query
198
+ this step runs; ``None`` and non-positive values fall back to the
199
+ module-level :data:`QUERY_TIMEOUT_S` so a mis-configured client
200
+ cannot silently disable the watchdog.
201
+ """
202
+ if timeout_s is None or timeout_s <= 0:
203
+ return QUERY_TIMEOUT_S
204
+ return float(timeout_s)
205
+
206
+
207
+ def _initial_schema_synopsis(spec: ScenarioSpec, synopsis: str) -> str:
208
+ """Reset-time synopsis with future drift details removed.
209
+
210
+ Drift scenarios should not reveal the exact schema/business-rule
211
+ change before the changelog is published at runtime. We therefore
212
+ trim the authored synopsis at the first ``" Under drift"`` clause on
213
+ reset and only surface the pre-drift schema shape.
214
+ """
215
+ if spec.drift_config is None:
216
+ return synopsis
217
+ predrift, marker, _ = synopsis.partition(" Under drift")
218
+ return predrift if marker else synopsis
219
+
220
+
221
+ def _validate_read_only_sql(sql: str) -> None:
222
+ """Reject anything that isn't a single-statement read-only SELECT/CTE.
223
+
224
+ Raises ``ValueError`` so the caller can translate to a typed
225
+ :class:`models.ToolError` with :attr:`ToolErrorCode.INVALID_TOOL_ARGUMENT`.
226
+ This is the only place that mediates what the policy may execute;
227
+ scenario builders and drift DDL call DuckDB directly with privileged
228
+ SQL and deliberately bypass this check.
229
+
230
+ Beyond the statement-level gate, this walker also rejects two
231
+ sandbox-escape vectors that would otherwise ride along inside a
232
+ perfectly-shaped SELECT:
233
+
234
+ 1. Table-valued functions that read from the host filesystem
235
+ (``read_csv``, ``read_parquet``, ``read_json_auto``, ``glob``,
236
+ ``read_text``, …) or leak engine introspection (``duckdb_secrets``
237
+ carries credentials; ``duckdb_settings`` /``duckdb_functions``
238
+ can enumerate available exploits). See :data:`_DENYLIST_PREFIXES`
239
+ / :data:`_DENYLIST_EXACT`.
240
+ 2. ``SELECT * FROM 'path/to/x.csv'`` — DuckDB treats a bare string
241
+ literal in a FROM clause as a filesystem path and auto-detects
242
+ the format. There is no function node to inspect in this form,
243
+ so we separately reject any :class:`sqlglot.exp.Table` whose
244
+ backing expression is a string literal.
245
+ """
246
+ try:
247
+ statements = sqlglot.parse(sql, dialect="duckdb")
248
+ except sqlglot.errors.ParseError as exc:
249
+ raise ValueError(f"SQL failed to parse: {exc}") from exc
250
+
251
+ non_empty = [s for s in statements if s is not None]
252
+ if len(non_empty) != 1:
253
+ raise ValueError("multi-statement SQL is not allowed; submit one SELECT")
254
+ expr = non_empty[0]
255
+ if expr.key not in _READ_ONLY_EXPRESSION_KEYS:
256
+ raise ValueError(
257
+ f"only read-only SELECT/CTE queries are allowed (got {expr.key.upper()} statement)"
258
+ )
259
+
260
+ for node in expr.walk():
261
+ # (1) Function-valued sandbox escapes. Inspect both the class
262
+ # name (catches ``ReadCSV`` / ``ReadParquet`` lowerings where
263
+ # ``.name`` holds the file path, not the function name) and
264
+ # ``.name`` (catches the generic ``Anonymous`` form).
265
+ if isinstance(node, sqlglot.exp.Func):
266
+ for fn_name in _function_names(node):
267
+ if _is_denylisted_function_name(fn_name):
268
+ raise ValueError(
269
+ f"function {fn_name!r} is not allowed — agent-facing SQL may "
270
+ "only touch the scenario's in-memory tables"
271
+ )
272
+ # (2) Bare-path FROM form: ``SELECT * FROM 'x.csv'`` or
273
+ # ``SELECT * FROM '/etc/passwd'``. sqlglot normalises both
274
+ # single- and double-quoted identifiers to
275
+ # ``Identifier(quoted=True)``, so we can't rely on the quote
276
+ # flavour to distinguish a file path from a legitimately-quoted
277
+ # table name. Instead we require every agent-facing table name
278
+ # to be a valid unquoted SQL identifier — the scenarios never
279
+ # emit anything else, and paths always contain ``/``, ``.`` or
280
+ # ``~`` which fail the identifier regex.
281
+ if isinstance(node, sqlglot.exp.Table):
282
+ inner = node.this
283
+ if isinstance(inner, sqlglot.exp.Identifier):
284
+ ident_name = inner.name
285
+ if ident_name and not _VALID_IDENTIFIER_RE.match(ident_name):
286
+ raise ValueError(
287
+ f"table identifier {ident_name!r} is not a valid unquoted SQL "
288
+ "name — reading from file paths or other engine-specific "
289
+ "resources is not allowed"
290
+ )
291
+
292
+
293
+ class SqlDriftEnvironment(Environment[SqlDriftAction, SqlDriftObservation, SqlDriftState]):
294
+ """OpenEnv environment for SQL repair + optimization under schema drift."""
295
+
296
+ SUPPORTS_CONCURRENT_SESSIONS = True
297
+
298
+ def __init__(
299
+ self,
300
+ skill_store: Store | None = None,
301
+ cleanup_on_close: bool = False,
302
+ ) -> None:
303
+ self._runtime: RuntimeEpisodeState | None = None
304
+ self._skill_store: Store | None = skill_store
305
+ # When True, the skill-store directory is deleted when close() is called.
306
+ # Set this for server-managed per-session stores so disk usage doesn't grow
307
+ # monotonically; see design/codereview.md (session store issue).
308
+ self._cleanup_on_close: bool = cleanup_on_close
309
+ super().__init__(
310
+ rubric=SqlDriftRubric(ctx_provider=lambda: self._require_runtime()),
311
+ )
312
+
313
+ # ------------------------------------------------------------------
314
+ # OpenEnv contract
315
+ # ------------------------------------------------------------------
316
+
317
+ @log_env_reset
318
+ def reset(
319
+ self,
320
+ seed: int | None = None,
321
+ episode_id: str | None = None,
322
+ **kwargs: Any,
323
+ ) -> SqlDriftObservation:
324
+ options = _ResetOptions.model_validate(kwargs)
325
+ scenario_id = options.scenario_id
326
+ enable_dba_oracle = dba_oracle.is_enabled(options.enable_dba_oracle)
327
+ difficulty = options.difficulty
328
+ budget_steps = options.budget_steps
329
+
330
+ if seed is None:
331
+ seed = secrets.randbits(31)
332
+ if episode_id is None:
333
+ episode_id = f"ep-{seed:08x}"
334
+ if scenario_id is None:
335
+ scenario_id = self._pick_scenario_for_seed(seed)
336
+
337
+ spec = get_spec(scenario_id)
338
+ instance = spec.materialize(seed, difficulty=difficulty)
339
+
340
+ drift_scheduled_step: int | None = None
341
+ if instance.drift_config is not None:
342
+ drift_scheduled_step = Random(seed).randint(
343
+ instance.drift_config.min_step,
344
+ instance.drift_config.max_step,
345
+ )
346
+
347
+ self._close_existing_runtime()
348
+ self._runtime = RuntimeEpisodeState(
349
+ episode_id=episode_id,
350
+ seed=seed,
351
+ scenario_id=scenario_id,
352
+ instance=instance,
353
+ conn=instance.conn,
354
+ gt_result_hash_predrift=instance.gt_result_hash_predrift,
355
+ gt_result_hash_postdrift=None,
356
+ baseline_runtime_ms=instance.baseline_runtime_ms,
357
+ baseline_tokens=instance.baseline_tokens,
358
+ baseline_sql_canonical=canonicalize_sql(instance.baseline_sql),
359
+ baseline_postdrift_raises=False,
360
+ drift_scheduled_step=drift_scheduled_step,
361
+ budget_steps=budget_steps,
362
+ dba_oracle_enabled=enable_dba_oracle,
363
+ )
364
+
365
+ self._reset_rubric()
366
+
367
+ learned_hints = kwargs.get("learned_hints")
368
+ if learned_hints is None:
369
+ learned_hints = self._render_learned_hints(spec, include_drift_cards=False)
370
+ if len(learned_hints) > 800:
371
+ learned_hints = learned_hints[:800]
372
+
373
+ rt = self._require_runtime()
374
+ return SqlDriftObservation(
375
+ step=0,
376
+ phase=EpisodePhase.DIAGNOSE,
377
+ last_tool=None,
378
+ tool_result=None,
379
+ drift_fired=False,
380
+ drift_acknowledged=False,
381
+ learned_hints=learned_hints,
382
+ baseline_sql=instance.baseline_sql,
383
+ schema_synopsis=_initial_schema_synopsis(spec, instance.schema_synopsis),
384
+ budget_steps_remaining=rt.budget_steps_remaining,
385
+ reward_components={key: 0.0 for key in REWARD_COMPONENT_KEYS},
386
+ done=False,
387
+ reward=None,
388
+ )
389
+
390
+ @log_env_step
391
+ def step(
392
+ self,
393
+ action: SqlDriftAction,
394
+ timeout_s: float | None = None,
395
+ **kwargs: Any,
396
+ ) -> SqlDriftObservation:
397
+ rt = self._require_runtime()
398
+ if rt.submitted or rt.budget_steps_remaining <= 0:
399
+ raise ValueError("Episode is already finished; call reset() to start a new episode.")
400
+ rt.step_count += 1
401
+ rt.last_step_was_tool_error = False
402
+ rt.last_step_was_repeat_failing_query = False
403
+ rt.last_step_repeat_failing_query_count = 0
404
+ rt.last_step_productive_rebate = 0.0
405
+
406
+ self._maybe_fire_drift()
407
+
408
+ effective_timeout_s = _resolve_timeout_s(timeout_s)
409
+ try:
410
+ tool_result = self._dispatch(action, timeout_s=effective_timeout_s)
411
+ except QueryWatchdogEscalationError:
412
+ rt.connection_poisoned = True
413
+ rt.phase = EpisodePhase.FINALIZE
414
+ rt.step_count = max(rt.step_count, rt.budget_steps)
415
+ _LOG.error("episode %s aborted after watchdog escalation", rt.episode_id)
416
+ raise
417
+ rt.last_step_was_tool_error = isinstance(tool_result, ToolError)
418
+ if rt.last_step_was_tool_error:
419
+ rt.consecutive_tool_errors += 1
420
+ else:
421
+ rt.consecutive_tool_errors = 0
422
+
423
+ done = rt.submitted or rt.budget_steps_remaining <= 0
424
+
425
+ obs = SqlDriftObservation(
426
+ step=rt.step_count,
427
+ phase=rt.phase,
428
+ last_tool=action.tool,
429
+ tool_result=tool_result,
430
+ drift_fired=rt.drift_fired,
431
+ drift_acknowledged=rt.drift_acknowledged,
432
+ learned_hints="",
433
+ baseline_sql="",
434
+ schema_synopsis="",
435
+ budget_steps_remaining=rt.budget_steps_remaining,
436
+ reward_components={key: 0.0 for key in REWARD_COMPONENT_KEYS},
437
+ done=done,
438
+ reward=None,
439
+ )
440
+ if rt.drift_acknowledged:
441
+ spec = get_spec(rt.scenario_id)
442
+ obs.learned_hints = self._render_learned_hints(spec, include_drift_cards=True)
443
+
444
+ obs.reward = self._apply_rubric(action, obs)
445
+ if self.rubric is not None:
446
+ obs.reward_components = self.rubric.component_scores()
447
+
448
+ if done and rt.submitted:
449
+ self._maybe_persist_learned_entry()
450
+ return obs
451
+
452
+ def render(self) -> dict[str, Any]:
453
+ """Render the current public state and log the render interaction."""
454
+ rt = self._require_runtime()
455
+ state = self.state
456
+ payload = state.model_dump(mode="json")
457
+ log_interaction(
458
+ event_type="render",
459
+ agent_id=rt.episode_id,
460
+ observation_returned=payload,
461
+ done=rt.submitted or rt.budget_steps_remaining <= 0,
462
+ )
463
+ return payload
464
+
465
+ @property
466
+ def state(self) -> SqlDriftState:
467
+ """Sanitised public state snapshot (explicit whitelist)."""
468
+ rt = self._require_runtime()
469
+ return SqlDriftState(
470
+ episode_id=rt.episode_id,
471
+ step_count=rt.step_count,
472
+ scenario_id=rt.scenario_id,
473
+ phase=rt.phase,
474
+ budget_steps_remaining=rt.budget_steps_remaining,
475
+ drift_fired=rt.drift_fired,
476
+ consultations_used=rt.consultations_used,
477
+ submitted=rt.submitted,
478
+ )
479
+
480
+ def effective_speedup(self) -> float | None:
481
+ """Return the current episode's effective speedup, if any."""
482
+ rt = self._runtime
483
+ if rt is None:
484
+ return None
485
+ return effective_speedup(rt)
486
+
487
+ def close(self) -> None:
488
+ self._close_existing_runtime()
489
+ if self._cleanup_on_close and self._skill_store is not None:
490
+ import shutil
491
+
492
+ store_dir = self._skill_store.dir
493
+ shutil.rmtree(store_dir, ignore_errors=True)
494
+
495
+ # ------------------------------------------------------------------
496
+ # Skill-library wiring
497
+ # ------------------------------------------------------------------
498
+
499
+ def _render_learned_hints(self, spec: ScenarioSpec, *, include_drift_cards: bool = True) -> str:
500
+ playbook, drift_cards = load_all(self._skill_store)
501
+ drift_kind = None
502
+ if include_drift_cards and spec.drift_config is not None:
503
+ drift_kind = spec.drift_config.kind
504
+ result = retrieve(
505
+ query_tags=spec.tags,
506
+ drift_kind=drift_kind,
507
+ playbook=playbook,
508
+ drift_cards=drift_cards,
509
+ )
510
+ return result.render(max_chars=800)
511
+
512
+ def _maybe_persist_learned_entry(self) -> None:
513
+ """Append a PlaybookEntry on terminal success with a meaningful speedup.
514
+
515
+ Failures to persist are logged but never re-raised: a training
516
+ rollout should not crash because the on-disk playbook is under
517
+ contention. The skill store itself is crash-safe (atomic writes
518
+ + file-lock) so at-most-once semantics are sufficient here.
519
+ """
520
+ if self._skill_store is None:
521
+ return
522
+ rt = self._require_runtime()
523
+ if not rt.submitted:
524
+ return
525
+ if self.rubric is None:
526
+ return
527
+ scores = self.rubric.component_scores()
528
+ if scores.get("r_correct", 0.0) < 1.0:
529
+ return
530
+ spec = get_spec(rt.scenario_id)
531
+ raw_speedup = effective_speedup(rt)
532
+ # effective_speedup cannot return None here — rt.submitted is True
533
+ # so submitted_runtime_ms is populated — but we guard defensively.
534
+ # ``+∞`` (drift invalidated the baseline) is capped so the on-disk
535
+ # playbook doesn't serialize ``Infinity``, which would round-trip
536
+ # as a JSON parse error on load.
537
+ if raw_speedup is None or math.isinf(raw_speedup):
538
+ speedup_val = float(SPEEDUP_CAP_FOR_INFTY)
539
+ else:
540
+ speedup_val = float(raw_speedup)
541
+ entry = PlaybookEntry(
542
+ tag_set=spec.tags,
543
+ before_snippet=rt.instance.baseline_sql[:200],
544
+ after_snippet=(rt.submitted_sql or "")[:200],
545
+ avg_speedup=speedup_val,
546
+ scenario_family=spec.family,
547
+ source="learned",
548
+ )
549
+ try:
550
+ self._skill_store.append_playbook(entry)
551
+ except Exception as exc:
552
+ _LOG.warning("skill-library append_playbook failed: %s", exc)
553
+
554
+ # ------------------------------------------------------------------
555
+ # Internal helpers
556
+ # ------------------------------------------------------------------
557
+
558
+ def _grant_step_rebate_once(self, *, attr: str, rebate: float) -> None:
559
+ rt = self._require_runtime()
560
+ if getattr(rt, attr):
561
+ return
562
+ setattr(rt, attr, True)
563
+ rt.last_step_productive_rebate += rebate
564
+
565
+ def _grant_step_rebate_for_table(
566
+ self, *, rewarded_tables_attr: str, table: str, rebate: float
567
+ ) -> None:
568
+ rt = self._require_runtime()
569
+ rewarded = getattr(rt, rewarded_tables_attr)
570
+ if table in rewarded:
571
+ return
572
+ rewarded.add(table)
573
+ rt.last_step_productive_rebate += rebate
574
+
575
+ @staticmethod
576
+ def _pick_scenario_for_seed(seed: int) -> str:
577
+ """Deterministic round-robin over the sorted scenario registry."""
578
+ ids = sorted(REGISTRY)
579
+ if not ids:
580
+ raise RuntimeError("no scenarios registered")
581
+ return ids[seed % len(ids)]
582
+
583
+ def _require_runtime(self) -> RuntimeEpisodeState:
584
+ if self._runtime is None:
585
+ raise RuntimeError("SqlDriftEnvironment.reset() must be called before step()/state.")
586
+ return self._runtime
587
+
588
+ def _close_existing_runtime(self) -> None:
589
+ if self._runtime is not None:
590
+ if self._runtime.connection_poisoned:
591
+ _LOG.error(
592
+ "skipping close for poisoned DuckDB connection in episode %s",
593
+ self._runtime.episode_id,
594
+ )
595
+ else:
596
+ with contextlib.suppress(duckdb.Error):
597
+ self._runtime.conn.close()
598
+ self._runtime = None
599
+
600
+ def _maybe_fire_drift(self) -> None:
601
+ """Apply drift when the step index crosses the schedule/cooldown threshold."""
602
+ rt = self._require_runtime()
603
+ if rt.drift_fired:
604
+ return
605
+ if rt.drift_scheduled_step is None:
606
+ return
607
+ if rt.first_run_query_step is None:
608
+ return
609
+ cfg = rt.instance.drift_config
610
+ assert cfg is not None
611
+ minimum = max(rt.drift_scheduled_step, rt.first_run_query_step + cfg.cooldown_steps)
612
+ if rt.step_count < minimum:
613
+ return
614
+ self._fire_drift()
615
+
616
+ def _fire_drift(self) -> None:
617
+ """Apply drift, author a changelog, and resolve the post-drift GT hash.
618
+
619
+ Failure to recompute the post-drift GT hash is an authoring bug
620
+ (the scenario's ``gt_sql_postdrift`` must execute against the
621
+ just-mutated DB) and we re-raise loudly so it cannot silently
622
+ make every post-drift submission score ``r_correct=0``.
623
+ """
624
+ rt = self._require_runtime()
625
+ cfg = rt.instance.drift_config
626
+ assert cfg is not None
627
+ apply_drift(rt.conn, cfg.kind, cfg.payload)
628
+ rt.drift_fired_step = rt.step_count
629
+ rt.phase = EpisodePhase.DRIFT_RECOVERY
630
+ rt.changelog_entries.append(author_changelog(cfg))
631
+
632
+ try:
633
+ rt.conn.execute(rt.instance.baseline_sql).fetchall()
634
+ rt.baseline_postdrift_raises = False
635
+ except duckdb.Error:
636
+ rt.baseline_postdrift_raises = True
637
+
638
+ if rt.instance.gt_sql_postdrift is not None:
639
+ try:
640
+ rows = rt.conn.execute(rt.instance.gt_sql_postdrift).fetchall()
641
+ except duckdb.Error as exc:
642
+ raise RuntimeError(
643
+ f"scenario {rt.scenario_id!r}: authored gt_sql_postdrift failed "
644
+ f"after drift: {exc}"
645
+ ) from exc
646
+ rt.gt_result_hash_postdrift = canonical_row_hash(rows)
647
+
648
+ # ------------------------------------------------------------------
649
+ # Tool dispatch
650
+ # ------------------------------------------------------------------
651
+
652
+ def _dispatch(self, action: SqlDriftAction, *, timeout_s: float) -> ToolResult:
653
+ payload = action.payload
654
+ try:
655
+ if isinstance(payload, ListTablesPayload):
656
+ return self._handle_list_tables()
657
+ if isinstance(payload, DescribeTablePayload):
658
+ return self._handle_describe_table(payload)
659
+ if isinstance(payload, SampleRowsPayload):
660
+ return self._handle_sample_rows(payload)
661
+ if isinstance(payload, RunQueryPayload):
662
+ return self._handle_run_query(payload, timeout_s=timeout_s)
663
+ if isinstance(payload, ExplainQueryPayload):
664
+ return self._handle_explain_query(payload, timeout_s=timeout_s)
665
+ if isinstance(payload, ReadChangelogPayload):
666
+ return self._handle_read_changelog()
667
+ if isinstance(payload, SubmitRewritePayload):
668
+ return self._handle_submit_rewrite(payload, timeout_s=timeout_s)
669
+ if isinstance(payload, ConsultDBAPayload):
670
+ return self._handle_consult_dba(payload)
671
+ except duckdb.Error as exc:
672
+ return ToolError(code=ToolErrorCode.DB_ERROR, message=str(exc)[:2000])
673
+ except TimeoutError as exc:
674
+ return ToolError(code=ToolErrorCode.QUERY_TIMEOUT, message=str(exc)[:2000])
675
+ # Unreachable — the discriminated-union validator rejects unknown payloads.
676
+ return ToolError(
677
+ code=ToolErrorCode.INVALID_TOOL_ARGUMENT,
678
+ message=f"unknown payload type: {type(payload).__name__}",
679
+ )
680
+
681
+ def _handle_list_tables(self) -> ListTablesResult:
682
+ rt = self._require_runtime()
683
+ rows = rt.conn.execute(
684
+ "SELECT table_name FROM information_schema.tables "
685
+ "WHERE table_schema = 'main' ORDER BY table_name"
686
+ ).fetchall()
687
+ self._grant_step_rebate_once(attr="listed_tables_rewarded", rebate=STEP_REBATE_LIST_TABLES)
688
+ self._mark_diagnostic()
689
+ return ListTablesResult(tables=[r[0] for r in rows])
690
+
691
+ def _handle_describe_table(
692
+ self, payload: DescribeTablePayload
693
+ ) -> DescribeTableResult | ToolError:
694
+ rt = self._require_runtime()
695
+ rows = rt.conn.execute(
696
+ "SELECT column_name, data_type FROM information_schema.columns "
697
+ "WHERE table_name = ? ORDER BY ordinal_position",
698
+ [payload.table],
699
+ ).fetchall()
700
+ if not rows:
701
+ return ToolError(
702
+ code=ToolErrorCode.UNKNOWN_TABLE,
703
+ message=f"unknown table: {payload.table}",
704
+ )
705
+ self._grant_step_rebate_for_table(
706
+ rewarded_tables_attr="described_tables_rewarded",
707
+ table=payload.table,
708
+ rebate=STEP_REBATE_DESCRIBE_TABLE,
709
+ )
710
+ self._mark_diagnostic()
711
+ return DescribeTableResult(
712
+ table=payload.table,
713
+ columns=[{"name": r[0], "type": r[1]} for r in rows],
714
+ )
715
+
716
+ def _handle_sample_rows(self, payload: SampleRowsPayload) -> SampleRowsResult | ToolError:
717
+ rt = self._require_runtime()
718
+ exists = rt.conn.execute(
719
+ "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = ?",
720
+ [payload.table],
721
+ ).fetchone()
722
+ if not exists or exists[0] == 0:
723
+ return ToolError(
724
+ code=ToolErrorCode.UNKNOWN_TABLE,
725
+ message=f"unknown table: {payload.table}",
726
+ )
727
+ cur = rt.conn.execute(f'SELECT * FROM "{payload.table}" LIMIT {payload.limit}')
728
+ columns = [d[0] for d in cur.description] if cur.description else []
729
+ rows = [list(r) for r in cur.fetchall()]
730
+ self._grant_step_rebate_for_table(
731
+ rewarded_tables_attr="sampled_tables_rewarded",
732
+ table=payload.table,
733
+ rebate=STEP_REBATE_SAMPLE_ROWS,
734
+ )
735
+ self._mark_diagnostic()
736
+ return SampleRowsResult(table=payload.table, columns=columns, rows=rows)
737
+
738
+ def _handle_run_query(
739
+ self, payload: RunQueryPayload, *, timeout_s: float
740
+ ) -> RunQueryResult | ToolError:
741
+ rt = self._require_runtime()
742
+ sql = payload.sql
743
+ try:
744
+ _validate_read_only_sql(sql)
745
+ except ValueError as exc:
746
+ return ToolError(code=ToolErrorCode.INVALID_TOOL_ARGUMENT, message=str(exc)[:2000])
747
+
748
+ # Drift timing: after a valid
749
+ # ``run_query`` attempt, the pre-drift probe invariant is
750
+ # satisfied regardless of whether the execution ultimately
751
+ # returned rows, raised, or was capped for size. Assigning
752
+ # *before* execution means truncation, DB errors, and timeouts
753
+ # can no longer suppress drift firing in later steps.
754
+ if rt.first_run_query_step is None:
755
+ rt.first_run_query_step = rt.step_count
756
+
757
+ try:
758
+ result = execute_once_with_columns(
759
+ rt.conn, sql, timeout_s=timeout_s, max_rows=MAX_RESULT_ROWS
760
+ )
761
+ except TimeoutError as exc:
762
+ return ToolError(code=ToolErrorCode.QUERY_TIMEOUT, message=str(exc)[:2000])
763
+ except duckdb.Error as exc:
764
+ # Canonicalize *before* hashing so whitespace-/case-only
765
+ # variants of the same broken query count as the same repeat
766
+ # offence. canonicalize_sql falls back to a whitespace fold
767
+ # for SQL that sqlglot can't parse — still normalises the
768
+ # vast majority of "retried the same typo" cases.
769
+ failure_hash = canonical_row_hash([(canonicalize_sql(sql),)])
770
+ count = rt.failed_query_counts.get(failure_hash, 0) + 1
771
+ rt.failed_query_counts[failure_hash] = count
772
+ rt.failed_query_hashes.add(failure_hash)
773
+ rt.last_step_repeat_failing_query_count = count
774
+ rt.last_step_was_repeat_failing_query = count > 1
775
+ return ToolError(code=ToolErrorCode.DB_ERROR, message=str(exc)[:2000])
776
+
777
+ if result.truncated:
778
+ return ToolError(
779
+ code=ToolErrorCode.RESULT_TOO_LARGE,
780
+ message=(
781
+ f"result exceeded {MAX_RESULT_ROWS}-row cap — narrow the "
782
+ "projection, add a LIMIT, or aggregate"
783
+ ),
784
+ )
785
+
786
+ self._grant_step_rebate_once(attr="run_query_rewarded", rebate=STEP_REBATE_RUN_QUERY)
787
+ self._mark_diagnostic()
788
+ return RunQueryResult(
789
+ columns=result.columns,
790
+ rows=[list(r) for r in result.rows],
791
+ runtime_ms=result.elapsed_ms,
792
+ row_count=len(result.rows),
793
+ )
794
+
795
+ def _handle_explain_query(
796
+ self, payload: ExplainQueryPayload, *, timeout_s: float
797
+ ) -> ExplainQueryResult | ToolError:
798
+ rt = self._require_runtime()
799
+ try:
800
+ _validate_read_only_sql(payload.sql)
801
+ except ValueError as exc:
802
+ return ToolError(code=ToolErrorCode.INVALID_TOOL_ARGUMENT, message=str(exc)[:2000])
803
+ # EXPLAIN is plan-only (no data materialisation) but we still
804
+ # route it through the watchdog so a pathological query cannot
805
+ # burn the step budget past the caller's wall-clock deadline.
806
+ explain_rows, _ = execute_once_timed(rt.conn, f"EXPLAIN {payload.sql}", timeout_s=timeout_s)
807
+ plan = "\n".join(str(r[-1]) if r else "" for r in explain_rows)
808
+ self._grant_step_rebate_once(
809
+ attr="explain_query_rewarded", rebate=STEP_REBATE_EXPLAIN_QUERY
810
+ )
811
+ self._mark_diagnostic()
812
+ return ExplainQueryResult(plan=plan[:10_000])
813
+
814
+ def _handle_read_changelog(self) -> ReadChangelogResult:
815
+ rt = self._require_runtime()
816
+ if rt.changelog_entries:
817
+ rt.drift_acknowledged = True
818
+ self._grant_step_rebate_once(
819
+ attr="changelog_rewarded_after_drift",
820
+ rebate=STEP_REBATE_READ_CHANGELOG,
821
+ )
822
+ self._mark_diagnostic()
823
+ return ReadChangelogResult(entries=list(rt.changelog_entries))
824
+
825
+ def _handle_submit_rewrite(
826
+ self, payload: SubmitRewritePayload, *, timeout_s: float
827
+ ) -> SubmitRewriteResult | ToolError:
828
+ rt = self._require_runtime()
829
+ if not rt.diagnostic_actions_taken:
830
+ return ToolError(
831
+ code=ToolErrorCode.SUBMIT_BEFORE_DIAGNOSE,
832
+ message=(
833
+ "submit_rewrite rejected: the agent must take at least one "
834
+ "diagnostic action (list_tables, describe_table, sample_rows, "
835
+ "run_query, explain_query, or read_changelog) before submitting."
836
+ ),
837
+ )
838
+ sql = payload.sql
839
+ try:
840
+ _validate_read_only_sql(sql)
841
+ except ValueError as exc:
842
+ return ToolError(code=ToolErrorCode.INVALID_TOOL_ARGUMENT, message=str(exc)[:2000])
843
+ try:
844
+ agent_hash, elapsed_ms = execute_hash_timed(rt.conn, sql, timeout_s=timeout_s)
845
+ except TimeoutError as exc:
846
+ return ToolError(code=ToolErrorCode.QUERY_TIMEOUT, message=str(exc)[:2000])
847
+ except duckdb.Error as exc:
848
+ return ToolError(code=ToolErrorCode.DB_ERROR, message=str(exc)[:2000])
849
+ gt_hash = (
850
+ rt.gt_result_hash_postdrift
851
+ if rt.drift_fired and rt.gt_result_hash_postdrift is not None
852
+ else rt.gt_result_hash_predrift
853
+ )
854
+ matches = agent_hash == gt_hash
855
+
856
+ rt.submitted = True
857
+ rt.submitted_sql = sql
858
+ rt.submitted_sql_canonical = canonicalize_sql(sql)
859
+ rt.submitted_result_hash = agent_hash
860
+ rt.submitted_runtime_ms = elapsed_ms
861
+ rt.phase = EpisodePhase.FINALIZE
862
+ return SubmitRewriteResult(
863
+ accepted=True,
864
+ runtime_ms=elapsed_ms,
865
+ matches_ground_truth=matches,
866
+ )
867
+
868
+ def _handle_consult_dba(self, payload: ConsultDBAPayload) -> ConsultDBAResult | ToolError:
869
+ rt = self._require_runtime()
870
+ if not rt.dba_oracle_enabled:
871
+ return ToolError(
872
+ code=ToolErrorCode.INVALID_TOOL_ARGUMENT,
873
+ message="consult_dba disabled — set enable_dba_oracle=True at reset()",
874
+ )
875
+ if not dba_oracle.has_hints(rt.scenario_id):
876
+ return ToolError(
877
+ code=ToolErrorCode.INVALID_TOOL_ARGUMENT,
878
+ message=f"no DBA hints registered for scenario={rt.scenario_id!r}",
879
+ )
880
+ rt.consultations_used += 1
881
+ tier = min(rt.consultations_used, 3)
882
+ hint = dba_oracle.get_hint(rt.scenario_id, tier)
883
+ del payload # question is free-text context only; hints are scenario-keyed.
884
+ return ConsultDBAResult(tier=tier, hint=hint)
885
+
886
+ def _mark_diagnostic(self) -> None:
887
+ """Record a successful diagnostic tool call and advance the phase machine."""
888
+ rt = self._require_runtime()
889
+ rt.diagnostic_actions_taken += 1
890
+ if rt.phase == EpisodePhase.DIAGNOSE:
891
+ rt.phase = EpisodePhase.REWRITE
892
+
893
+
894
+ __all__ = [
895
+ "DEFAULT_STEP_BUDGET",
896
+ "MAX_RESULT_ROWS",
897
+ "QUERY_TIMEOUT_S",
898
+ "SqlDriftEnvironment",
899
+ ]
skill_library/__init__.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Self-improvement library — pre-seeded playbook + on-disk learned store."""
2
+
3
+ from skill_library.entries import (
4
+ DriftAdaptationCard,
5
+ PlaybookEntry,
6
+ RetrievalResult,
7
+ )
8
+ from skill_library.retrieval import (
9
+ JACCARD_MIN,
10
+ jaccard,
11
+ retrieve,
12
+ top_k_drift_cards,
13
+ top_k_playbook,
14
+ )
15
+ from skill_library.seeds import PRESEED_DRIFT_CARDS, PRESEED_PLAYBOOK
16
+ from skill_library.store import DEFAULT_STORE_DIR, Store, cleanup_stale_session_dirs
17
+
18
+
19
+ def load_all(
20
+ store: Store | None = None,
21
+ ) -> tuple[
22
+ tuple[PlaybookEntry, ...],
23
+ tuple[DriftAdaptationCard, ...],
24
+ ]:
25
+ """Union of pre-seeds and any entries persisted on disk.
26
+
27
+ Returns ``(playbook, drift_cards)``. Order: pre-seeds first, then
28
+ learned entries, so deterministic retrieval tie-breaks prefer
29
+ the hand-authored pre-seeds when tags and speedup match exactly.
30
+ """
31
+ learned_pb: tuple[PlaybookEntry, ...] = ()
32
+ learned_dc: tuple[DriftAdaptationCard, ...] = ()
33
+ if store is not None:
34
+ learned_pb = store.read_playbook()
35
+ learned_dc = store.read_drift_cards()
36
+ return PRESEED_PLAYBOOK + learned_pb, PRESEED_DRIFT_CARDS + learned_dc
37
+
38
+
39
+ __all__ = [
40
+ "DEFAULT_STORE_DIR",
41
+ "DriftAdaptationCard",
42
+ "JACCARD_MIN",
43
+ "PRESEED_DRIFT_CARDS",
44
+ "PRESEED_PLAYBOOK",
45
+ "PlaybookEntry",
46
+ "RetrievalResult",
47
+ "Store",
48
+ "cleanup_stale_session_dirs",
49
+ "jaccard",
50
+ "load_all",
51
+ "retrieve",
52
+ "top_k_drift_cards",
53
+ "top_k_playbook",
54
+ ]
skill_library/entries.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data classes for the self-improvement (skill) library.
2
+
3
+ Two entry kinds — both immutable dataclasses so they can live on
4
+ frozen-dict caches and be safely shared across episodes.
5
+
6
+ - :class:`PlaybookEntry` — a "before/after" SQL rewrite nugget,
7
+ tagged by anti-pattern + scenario family, with an empirical
8
+ speedup number. Populated by pre-seeds and extended at
9
+ terminal-success (``r_correct > 0 ∧ speedup > 1.2``).
10
+ - :class:`DriftAdaptationCard` — a drift-kind recovery card with a
11
+ symptom regex and a recovery template. Pre-seeded 1-per-drift-kind.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from dataclasses import dataclass, field
17
+ from typing import Literal
18
+
19
+ DriftKind = Literal["column_rename", "date_format", "enum_rule", "field_deprecation"]
20
+
21
+
22
+ @dataclass(frozen=True)
23
+ class PlaybookEntry:
24
+ """A reusable SQL-rewrite recipe keyed by anti-pattern tags."""
25
+
26
+ tag_set: frozenset[str]
27
+ before_snippet: str
28
+ after_snippet: str
29
+ avg_speedup: float
30
+ scenario_family: str # matches scenarios.base.Family
31
+ # Optional provenance — "preseed" for hand-authored, "learned" for
32
+ # append-on-success entries. Used by the retrieval UI only.
33
+ source: Literal["preseed", "learned"] = "preseed"
34
+
35
+ def render_hint(self, max_chars: int = 200) -> str:
36
+ """Render a one-liner suitable for inclusion in ``learned_hints``."""
37
+ body = (
38
+ f"[{self.scenario_family}] "
39
+ f"replace `{self.before_snippet[:60]}...` with "
40
+ f"`{self.after_snippet[:60]}...` "
41
+ f"(~{self.avg_speedup:.1f}x)"
42
+ )
43
+ return body[:max_chars]
44
+
45
+
46
+ @dataclass(frozen=True)
47
+ class DriftAdaptationCard:
48
+ """A drift-kind recovery card."""
49
+
50
+ drift_kind: DriftKind
51
+ symptom_regex: str
52
+ recovery_template: str
53
+ success_rate: float = 0.0
54
+ source: Literal["preseed", "learned"] = "preseed"
55
+
56
+ def render_hint(self, max_chars: int = 200) -> str:
57
+ body = (
58
+ f"[drift:{self.drift_kind}] "
59
+ f"symptom=/{self.symptom_regex}/ → "
60
+ f"{self.recovery_template[:120]}"
61
+ )
62
+ return body[:max_chars]
63
+
64
+
65
+ @dataclass(frozen=True)
66
+ class RetrievalResult:
67
+ """Top-k blend of playbook hits + drift cards for one retrieval call."""
68
+
69
+ playbook: tuple[PlaybookEntry, ...] = field(default_factory=tuple)
70
+ drift_cards: tuple[DriftAdaptationCard, ...] = field(default_factory=tuple)
71
+
72
+ def render(self, max_chars: int = 800) -> str:
73
+ """Concatenate rendered hints, truncated to ``max_chars``.
74
+
75
+ Deterministic ordering: playbook entries first (by descending
76
+ ``avg_speedup``, ties broken by ``before_snippet``), then drift
77
+ cards (by descending ``success_rate``, ties by ``drift_kind``).
78
+ """
79
+ lines: list[str] = []
80
+ for e in self.playbook:
81
+ lines.append("- " + e.render_hint(max_chars=200))
82
+ for c in self.drift_cards:
83
+ lines.append("- " + c.render_hint(max_chars=200))
84
+ out = "\n".join(lines)
85
+ return out[:max_chars]
86
+
87
+
88
+ __all__ = [
89
+ "DriftAdaptationCard",
90
+ "DriftKind",
91
+ "PlaybookEntry",
92
+ "RetrievalResult",
93
+ ]
skill_library/retrieval.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pure-function Jaccard top-k retrieval over tag sets.
2
+
3
+ No embeddings, no torch, no tokenizer. Deterministic — same inputs
4
+ produce the same ranking, same top-k, same tie-break.
5
+
6
+ A conservative Jaccard threshold (0.3) limits retrieval noise when
7
+ broad pre-seeds would otherwise match every scenario.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from collections.abc import Iterable
13
+
14
+ from skill_library.entries import (
15
+ DriftAdaptationCard,
16
+ PlaybookEntry,
17
+ RetrievalResult,
18
+ )
19
+
20
+ JACCARD_MIN: float = 0.3
21
+
22
+
23
+ def jaccard(a: frozenset[str], b: frozenset[str]) -> float:
24
+ """Standard Jaccard on sets."""
25
+ if not a and not b:
26
+ return 1.0
27
+ if not a or not b:
28
+ return 0.0
29
+ return len(a & b) / len(a | b)
30
+
31
+
32
+ def top_k_playbook(
33
+ query_tags: frozenset[str],
34
+ entries: Iterable[PlaybookEntry],
35
+ k: int = 3,
36
+ *,
37
+ min_overlap: float = JACCARD_MIN,
38
+ ) -> tuple[PlaybookEntry, ...]:
39
+ """Top-k playbook entries by Jaccard overlap with ``query_tags``.
40
+
41
+ Ties broken by descending ``avg_speedup`` then by ``before_snippet``
42
+ lexicographic order so the result is stable across runs.
43
+ """
44
+ scored = [(jaccard(query_tags, e.tag_set), e) for e in entries]
45
+ scored = [(j, e) for j, e in scored if j >= min_overlap]
46
+ scored.sort(key=lambda t: (-t[0], -t[1].avg_speedup, t[1].before_snippet))
47
+ return tuple(e for _, e in scored[:k])
48
+
49
+
50
+ def top_k_drift_cards(
51
+ drift_kind: str | None,
52
+ cards: Iterable[DriftAdaptationCard],
53
+ k: int = 1,
54
+ ) -> tuple[DriftAdaptationCard, ...]:
55
+ """Filter cards by exact drift_kind match, sorted by success_rate desc."""
56
+ if drift_kind is None:
57
+ return ()
58
+ matches = [c for c in cards if c.drift_kind == drift_kind]
59
+ matches.sort(key=lambda c: (-c.success_rate, c.drift_kind))
60
+ return tuple(matches[:k])
61
+
62
+
63
+ def retrieve(
64
+ query_tags: frozenset[str],
65
+ drift_kind: str | None,
66
+ playbook: Iterable[PlaybookEntry],
67
+ drift_cards: Iterable[DriftAdaptationCard],
68
+ *,
69
+ playbook_k: int = 3,
70
+ drift_k: int = 1,
71
+ ) -> RetrievalResult:
72
+ """Combined retrieval: top-k playbook + top-k drift cards."""
73
+ return RetrievalResult(
74
+ playbook=top_k_playbook(query_tags, playbook, k=playbook_k),
75
+ drift_cards=top_k_drift_cards(drift_kind, drift_cards, k=drift_k),
76
+ )
77
+
78
+
79
+ __all__ = [
80
+ "JACCARD_MIN",
81
+ "jaccard",
82
+ "retrieve",
83
+ "top_k_drift_cards",
84
+ "top_k_playbook",
85
+ ]
skill_library/seeds.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hand-authored pre-seed library — 8 playbook entries + 4 drift cards.
2
+
3
+ Tag sets here match the tags on each corresponding scenario
4
+ so a Jaccard top-k against the scenario's tag set returns
5
+ a relevant entry for every one of the 10 scenarios — with two
6
+ generic fallbacks for coverage on novel anti-patterns.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from skill_library.entries import DriftAdaptationCard, PlaybookEntry
12
+
13
+ PRESEED_PLAYBOOK: tuple[PlaybookEntry, ...] = (
14
+ # 1. Correlated subquery → LEFT JOIN + GROUP BY
15
+ PlaybookEntry(
16
+ tag_set=frozenset({"correlated_subquery", "projection_subquery"}),
17
+ before_snippet=(
18
+ "SELECT u.*, (SELECT COUNT(*) FROM orders o WHERE o.user_id=u.id) FROM users u"
19
+ ),
20
+ after_snippet=(
21
+ "SELECT u.*, COALESCE(c.n, 0) FROM users u "
22
+ "LEFT JOIN (SELECT user_id, COUNT(*) n FROM orders GROUP BY user_id) c "
23
+ "ON c.user_id = u.id"
24
+ ),
25
+ avg_speedup=6.0,
26
+ scenario_family="ecommerce",
27
+ ),
28
+ # 2. SELECT * + join → project only needed columns
29
+ PlaybookEntry(
30
+ tag_set=frozenset({"select_star", "over_projection", "join"}),
31
+ before_snippet="SELECT * FROM a JOIN b ON a.id=b.a_id",
32
+ after_snippet="SELECT a.id, a.name, b.amount FROM a JOIN b ON a.id=b.a_id",
33
+ avg_speedup=2.5,
34
+ scenario_family="ecommerce",
35
+ ),
36
+ # 3. Cartesian join — add explicit ON clause
37
+ PlaybookEntry(
38
+ tag_set=frozenset({"cartesian", "missing_join_condition"}),
39
+ before_snippet="SELECT * FROM a, b WHERE a.region = 'US'",
40
+ after_snippet="SELECT a.col FROM a JOIN b ON a.id = b.a_id WHERE a.region = 'US'",
41
+ avg_speedup=50.0,
42
+ scenario_family="events",
43
+ ),
44
+ # 4. DISTINCT on GROUP BY — drop one
45
+ PlaybookEntry(
46
+ tag_set=frozenset({"distinct", "redundant_distinct", "group_by"}),
47
+ before_snippet="SELECT DISTINCT tenant_id, count(*) FROM logs GROUP BY tenant_id",
48
+ after_snippet="SELECT tenant_id, count(*) FROM logs GROUP BY tenant_id",
49
+ avg_speedup=1.4,
50
+ scenario_family="saas_logs",
51
+ ),
52
+ # 5. Nested IN-subquery → JOIN
53
+ PlaybookEntry(
54
+ tag_set=frozenset({"nested_subquery", "in_subquery"}),
55
+ before_snippet="WHERE id IN (SELECT x_id FROM x WHERE ... )",
56
+ after_snippet="JOIN x ON x.x_id = table.id WHERE ...",
57
+ avg_speedup=3.0,
58
+ scenario_family="cms",
59
+ ),
60
+ # 6. HAVING filter on groupable column → push to WHERE
61
+ PlaybookEntry(
62
+ tag_set=frozenset({"having_as_where", "aggregate_filter"}),
63
+ before_snippet="GROUP BY x, status HAVING status = 'fulfilled'",
64
+ after_snippet="WHERE status = 'fulfilled' GROUP BY x",
65
+ avg_speedup=2.0,
66
+ scenario_family="ecommerce",
67
+ ),
68
+ # 7. Generic: prefer JOINs over correlated subqueries
69
+ PlaybookEntry(
70
+ tag_set=frozenset({"subquery", "generic"}),
71
+ before_snippet="scalar subquery in SELECT list",
72
+ after_snippet="LEFT JOIN with aggregated CTE",
73
+ avg_speedup=4.0,
74
+ scenario_family="ecommerce",
75
+ ),
76
+ # 8. Generic: project only used columns
77
+ PlaybookEntry(
78
+ tag_set=frozenset({"over_projection", "generic"}),
79
+ before_snippet="SELECT *",
80
+ after_snippet="SELECT <only needed columns>",
81
+ avg_speedup=1.8,
82
+ scenario_family="ecommerce",
83
+ ),
84
+ )
85
+
86
+
87
+ PRESEED_DRIFT_CARDS: tuple[DriftAdaptationCard, ...] = (
88
+ DriftAdaptationCard(
89
+ drift_kind="column_rename",
90
+ symptom_regex=r'column ".+" does not exist',
91
+ recovery_template=(
92
+ "Read the changelog, update every identifier referencing the old column, and resubmit."
93
+ ),
94
+ success_rate=0.9,
95
+ ),
96
+ DriftAdaptationCard(
97
+ drift_kind="date_format",
98
+ symptom_regex=r"Could not convert string .+ to TIMESTAMP|BIGINT",
99
+ recovery_template=(
100
+ "Epoch-ms columns are BIGINT; cast your filter bounds with "
101
+ "`EXTRACT(EPOCH FROM TIMESTAMP '...') * 1000` or use numeric literals."
102
+ ),
103
+ success_rate=0.85,
104
+ ),
105
+ DriftAdaptationCard(
106
+ drift_kind="enum_rule",
107
+ symptom_regex=r"(empty|zero) result set on filter `... = 'active'`",
108
+ recovery_template=(
109
+ "A single enum value may have been split into several; use `IN "
110
+ "('ACTIVE', 'ACTIVE_V2')` instead of equality."
111
+ ),
112
+ success_rate=0.8,
113
+ ),
114
+ DriftAdaptationCard(
115
+ drift_kind="field_deprecation",
116
+ symptom_regex=r'column ".+" does not exist|non-existent column',
117
+ recovery_template=(
118
+ "The inline field was replaced by a FK; JOIN the lookup table and "
119
+ "project the human-readable name from there."
120
+ ),
121
+ success_rate=0.75,
122
+ ),
123
+ )
124
+
125
+
126
+ __all__ = ["PRESEED_DRIFT_CARDS", "PRESEED_PLAYBOOK"]
skill_library/store.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """JSON-backed playbook and drift-card store with file locking.
2
+
3
+ Each store file is a JSON array. Appends go through a single locked
4
+ read-modify-write cycle that writes to a sibling ``*.tmp`` file and
5
+ atomically ``os.replace``s it onto the target path, so a crash can only
6
+ leave either the old array or the new one — never a truncated file.
7
+
8
+ The lock is held on a dedicated ``*.lock`` file via ``fcntl.flock`` with
9
+ a caller-configurable timeout (default 5s). We never lock the data file
10
+ itself: that way an ``os.replace`` inside the critical section can't
11
+ race against a reader holding a shared lock on the old inode.
12
+
13
+ Reads are cached by mtime so hot-path episodes don't re-parse the file
14
+ on every ``reset()``. Corrupt trailers (from a pre-atomic-write era or
15
+ a partial disk write) log a warning and fall back to empty — we prefer
16
+ a running trainer over one that dies because of a bad card.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import contextlib
22
+ import json
23
+ import os
24
+ import time
25
+ from collections.abc import Callable, Iterator
26
+ from dataclasses import asdict
27
+ from pathlib import Path
28
+ from typing import Any, Literal, TypeVar, cast
29
+
30
+ from skill_library.entries import DriftAdaptationCard, PlaybookEntry
31
+ from utilities.logger import get_module_logger
32
+
33
+ _LOG = get_module_logger(__name__)
34
+
35
+ DEFAULT_STORE_DIR = Path("outputs") / "skill_library"
36
+ PLAYBOOK_FILENAME = "playbook.json"
37
+ DRIFT_CARDS_FILENAME = "drift_cards.json"
38
+ DEFAULT_LOCK_TIMEOUT_S: float = 5.0
39
+
40
+ T = TypeVar("T")
41
+
42
+
43
+ try:
44
+ import fcntl
45
+
46
+ def _try_lock_exclusive(fh: Any) -> bool:
47
+ try:
48
+ fcntl.flock(fh.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
49
+ return True
50
+ except BlockingIOError:
51
+ return False
52
+
53
+ def _unlock(fh: Any) -> None:
54
+ fcntl.flock(fh.fileno(), fcntl.LOCK_UN)
55
+
56
+ _HAS_FCNTL = True
57
+
58
+ except ImportError:
59
+ _HAS_FCNTL = False
60
+
61
+ def _try_lock_exclusive(fh: Any) -> bool:
62
+ return True
63
+
64
+ def _unlock(fh: Any) -> None:
65
+ return None
66
+
67
+
68
+ @contextlib.contextmanager
69
+ def _locked(path: Path, timeout_s: float) -> Iterator[None]:
70
+ """Poll-acquire an exclusive flock on ``path`` within ``timeout_s``."""
71
+ path.parent.mkdir(parents=True, exist_ok=True)
72
+ deadline = time.monotonic() + timeout_s
73
+ with path.open("a+") as fh:
74
+ while not _try_lock_exclusive(fh):
75
+ if time.monotonic() >= deadline:
76
+ raise TimeoutError(f"could not acquire {path} within {timeout_s}s")
77
+ time.sleep(0.02)
78
+ try:
79
+ yield
80
+ finally:
81
+ if _HAS_FCNTL:
82
+ _unlock(fh)
83
+
84
+
85
+ def _atomic_write_json(path: Path, payload: list[dict[str, Any]]) -> None:
86
+ tmp = path.with_suffix(path.suffix + ".tmp")
87
+ text = json.dumps(payload, indent=2)
88
+ with tmp.open("w", encoding="utf-8") as fh:
89
+ fh.write(text)
90
+ fh.flush()
91
+ os.fsync(fh.fileno())
92
+ os.replace(tmp, path)
93
+
94
+
95
+ def _read_json_array(path: Path) -> list[Any]:
96
+ if not path.exists():
97
+ return []
98
+ try:
99
+ raw = path.read_text(encoding="utf-8")
100
+ except OSError as exc:
101
+ _LOG.warning("skill-store read failed for %s: %s", path, exc)
102
+ return []
103
+ if not raw.strip():
104
+ return []
105
+ try:
106
+ data = json.loads(raw)
107
+ except json.JSONDecodeError as exc:
108
+ _LOG.warning("skill-store corrupt at %s (%s); returning empty", path, exc)
109
+ return []
110
+ return data if isinstance(data, list) else []
111
+
112
+
113
+ class Store:
114
+ """Append-only JSON store for learned playbook entries + drift cards."""
115
+
116
+ def __init__(
117
+ self,
118
+ directory: Path | None = None,
119
+ lock_timeout_s: float = DEFAULT_LOCK_TIMEOUT_S,
120
+ ) -> None:
121
+ self.dir = Path(directory) if directory is not None else DEFAULT_STORE_DIR
122
+ self.lock_timeout_s = lock_timeout_s
123
+ self._playbook_cache: tuple[tuple[PlaybookEntry, ...], float] | None = None
124
+ self._drift_cache: tuple[tuple[DriftAdaptationCard, ...], float] | None = None
125
+
126
+ def playbook_path(self) -> Path:
127
+ return self.dir / PLAYBOOK_FILENAME
128
+
129
+ def drift_cards_path(self) -> Path:
130
+ return self.dir / DRIFT_CARDS_FILENAME
131
+
132
+ def read_playbook(self) -> tuple[PlaybookEntry, ...]:
133
+ return self._read_cached(
134
+ self.playbook_path(),
135
+ cache_attr="_playbook_cache",
136
+ decode=_entry_from_dict,
137
+ )
138
+
139
+ def read_drift_cards(self) -> tuple[DriftAdaptationCard, ...]:
140
+ return self._read_cached(
141
+ self.drift_cards_path(),
142
+ cache_attr="_drift_cache",
143
+ decode=lambda d: DriftAdaptationCard(**d),
144
+ )
145
+
146
+ def append_playbook(self, entry: PlaybookEntry) -> None:
147
+ self._locked_append(
148
+ self.playbook_path(),
149
+ encode_new=_entry_to_dict,
150
+ new_item=entry,
151
+ )
152
+ self._playbook_cache = None
153
+
154
+ def append_drift_card(self, card: DriftAdaptationCard) -> None:
155
+ self._locked_append(
156
+ self.drift_cards_path(),
157
+ encode_new=asdict,
158
+ new_item=card,
159
+ )
160
+ self._drift_cache = None
161
+
162
+ def _read_cached(
163
+ self,
164
+ path: Path,
165
+ *,
166
+ cache_attr: str,
167
+ decode: Callable[[dict[str, Any]], T],
168
+ ) -> tuple[T, ...]:
169
+ mtime = _safe_mtime(path)
170
+ # ``getattr``/``setattr`` is intentional — the same implementation
171
+ # services both the playbook and drift-card caches, whose Python
172
+ # types differ. The cast below restores the precise
173
+ # ``(tuple[T, ...], float) | None`` shape for mypy.
174
+ cached = cast("tuple[tuple[T, ...], float] | None", getattr(self, cache_attr))
175
+ if cached is not None and cached[1] == mtime:
176
+ return cached[0]
177
+ items: list[T] = []
178
+ for d in _read_json_array(path):
179
+ try:
180
+ items.append(decode(d))
181
+ except (TypeError, KeyError, ValueError) as exc:
182
+ _LOG.warning("skipping malformed store entry %s: %s", d, exc)
183
+ tup = tuple(items)
184
+ setattr(self, cache_attr, (tup, mtime))
185
+ return tup
186
+
187
+ def _locked_append(
188
+ self,
189
+ path: Path,
190
+ *,
191
+ encode_new: Callable[[Any], dict[str, Any]],
192
+ new_item: Any,
193
+ ) -> None:
194
+ self.dir.mkdir(parents=True, exist_ok=True)
195
+ lock_path = path.with_suffix(path.suffix + ".lock")
196
+ with _locked(lock_path, self.lock_timeout_s):
197
+ existing = _read_json_array(path)
198
+ existing.append(encode_new(new_item))
199
+ _atomic_write_json(path, existing)
200
+
201
+
202
+ def _safe_mtime(path: Path) -> float:
203
+ try:
204
+ return path.stat().st_mtime
205
+ except FileNotFoundError:
206
+ return 0.0
207
+
208
+
209
+ def _entry_to_dict(e: PlaybookEntry) -> dict[str, Any]:
210
+ return {
211
+ "tag_set": sorted(e.tag_set),
212
+ "before_snippet": e.before_snippet,
213
+ "after_snippet": e.after_snippet,
214
+ "avg_speedup": e.avg_speedup,
215
+ "scenario_family": e.scenario_family,
216
+ "source": e.source,
217
+ }
218
+
219
+
220
+ def _entry_from_dict(d: dict[str, Any]) -> PlaybookEntry:
221
+ source: Literal["preseed", "learned"] = d.get("source", "learned")
222
+ return PlaybookEntry(
223
+ tag_set=frozenset(d.get("tag_set") or []),
224
+ before_snippet=d["before_snippet"],
225
+ after_snippet=d["after_snippet"],
226
+ avg_speedup=float(d["avg_speedup"]),
227
+ scenario_family=d["scenario_family"],
228
+ source=source,
229
+ )
230
+
231
+
232
+ def cleanup_stale_session_dirs(root: Path, ttl_hours: float) -> int:
233
+ """Remove session subdirectories under *root* whose mtime is older than *ttl_hours*.
234
+
235
+ Returns the number of directories removed. Errors on individual
236
+ subdirectories are logged and skipped so a single bad entry cannot abort
237
+ the sweep. Pass ``ttl_hours=0`` to disable (returns 0 immediately).
238
+ """
239
+ import shutil
240
+
241
+ if ttl_hours <= 0 or not root.exists():
242
+ return 0
243
+ cutoff = time.time() - ttl_hours * 3600
244
+ removed = 0
245
+ for session_dir in root.iterdir():
246
+ if not session_dir.is_dir():
247
+ continue
248
+ try:
249
+ if session_dir.stat().st_mtime < cutoff:
250
+ shutil.rmtree(session_dir, ignore_errors=True)
251
+ removed += 1
252
+ except OSError as exc:
253
+ _LOG.warning("cleanup_stale_session_dirs: skipping %s: %s", session_dir, exc)
254
+ return removed
255
+
256
+
257
+ __all__ = ["DEFAULT_LOCK_TIMEOUT_S", "DEFAULT_STORE_DIR", "Store", "cleanup_stale_session_dirs"]
sql_drift_grpo_training.ipynb ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# SQLDrift GRPO Training Notebook\n",
8
+ "\n",
9
+ "Runnable hackathon training artifact for SQLDrift. This notebook uses the repository's existing `training.grpo_train` entrypoint, which wires Unsloth, Hugging Face TRL `GRPOTrainer`, Transformers, PEFT LoRA, and the SQLDrift OpenEnv tool environment.\n",
10
+ "\n",
11
+ "Assumptions:\n",
12
+ "\n",
13
+ "- You run this on a GPU runtime, preferably Colab with an NVIDIA T4/A10/L4 or better.\n",
14
+ "- `SQL_DRIFT_ENV_URL` points at the deployed SQLDrift HF Space, for example `https://your-space.hf.space`.\n",
15
+ "- The notebook is run from the repository root, or `SQL_DRIFT_REPO_URL` is set so Colab can clone the repo.\n",
16
+ "\n",
17
+ "Success criteria:\n",
18
+ "\n",
19
+ "1. The deployed environment health check succeeds.\n",
20
+ "2. GRPO trains for the configured number of steps against live SQLDrift rollouts.\n",
21
+ "3. A LoRA adapter/checkpoint is written under `outputs/checkpoints/grpo_hackathon/`.\n",
22
+ "4. Durable evidence files are written under `training/evidence/`: `grpo_metrics.csv`, `grpo_loss_curve.png`, and `grpo_reward_curve.png`.\n",
23
+ "\n",
24
+ "After pushing this notebook to GitHub, open it in Colab by using:\n",
25
+ "`https://colab.research.google.com/github/<YOUR_GITHUB_USER>/<YOUR_REPO>/blob/main/training/sql_drift_grpo_training.ipynb`\n"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "markdown",
30
+ "metadata": {},
31
+ "source": [
32
+ "## 1. Install The Training Stack\n",
33
+ "\n",
34
+ "If this notebook is launched directly in Colab, set `SQL_DRIFT_REPO_URL` first. If it is already running from the repository root, the clone step is skipped.\n"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "from pathlib import Path\n",
44
+ "import os\n",
45
+ "import subprocess\n",
46
+ "import sys\n",
47
+ "\n",
48
+ "REPO_URL = os.environ.get(\"SQL_DRIFT_REPO_URL\", \"\").strip()\n",
49
+ "repo_root = Path.cwd()\n",
50
+ "\n",
51
+ "if not (repo_root / \"pyproject.toml\").exists():\n",
52
+ " if not REPO_URL:\n",
53
+ " raise RuntimeError(\n",
54
+ " \"Set SQL_DRIFT_REPO_URL to your GitHub repo URL, or run this notebook \"\n",
55
+ " \"from the sql_drift_env repository root.\"\n",
56
+ " )\n",
57
+ " subprocess.check_call([\"git\", \"clone\", REPO_URL, \"sql_drift_env\"])\n",
58
+ " os.chdir(\"sql_drift_env\")\n",
59
+ " repo_root = Path.cwd()\n",
60
+ "\n",
61
+ "subprocess.check_call(\n",
62
+ " [\n",
63
+ " sys.executable,\n",
64
+ " \"-m\",\n",
65
+ " \"pip\",\n",
66
+ " \"install\",\n",
67
+ " \"-q\",\n",
68
+ " \"-e\",\n",
69
+ " \".[train]\",\n",
70
+ " \"matplotlib\",\n",
71
+ " \"pandas\",\n",
72
+ " \"peft\",\n",
73
+ " \"bitsandbytes\",\n",
74
+ " ]\n",
75
+ ")\n",
76
+ "\n",
77
+ "print(f\"Repo root: {repo_root}\")"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "markdown",
82
+ "metadata": {},
83
+ "source": [
84
+ "## 2. Point At The Deployed SQLDrift Environment\n",
85
+ "\n",
86
+ "Set `SQL_DRIFT_ENV_URL` to the HF Space URL before running the cell, for example:\n",
87
+ "\n",
88
+ "```python\n",
89
+ "import os\n",
90
+ "os.environ[\"SQL_DRIFT_ENV_URL\"] = \"https://your-space.hf.space\"\n",
91
+ "```\n"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": null,
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "import json\n",
101
+ "import urllib.request\n",
102
+ "\n",
103
+ "ENV_BASE_URL = os.environ.get(\"SQL_DRIFT_ENV_URL\", \"\").strip()\n",
104
+ "if not ENV_BASE_URL:\n",
105
+ " raise RuntimeError(\"Set SQL_DRIFT_ENV_URL to the deployed SQLDrift HF Space URL.\")\n",
106
+ "\n",
107
+ "health_url = ENV_BASE_URL.rstrip(\"/\") + \"/health\"\n",
108
+ "with urllib.request.urlopen(health_url, timeout=30) as response:\n",
109
+ " body = response.read().decode(\"utf-8\")\n",
110
+ "\n",
111
+ "print(f\"Health check OK: {health_url}\")\n",
112
+ "print(body[:500])"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "markdown",
117
+ "metadata": {},
118
+ "source": [
119
+ "## 3. Configure A Minimal GRPO Run\n",
120
+ "\n",
121
+ "Defaults are intentionally small enough for a hackathon Colab run while still producing non-trivial curves. Increase `SQL_DRIFT_GRPO_MAX_STEPS` for final evidence.\n"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": null,
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "from dataclasses import asdict\n",
131
+ "from pathlib import Path\n",
132
+ "\n",
133
+ "from training.config import ALL_SCENARIOS, CurriculumConfig, GRPOConfig\n",
134
+ "\n",
135
+ "MAX_STEPS = int(os.environ.get(\"SQL_DRIFT_GRPO_MAX_STEPS\", \"80\"))\n",
136
+ "GROUP_SIZE = int(os.environ.get(\"SQL_DRIFT_GRPO_GROUP_SIZE\", \"4\"))\n",
137
+ "SEED = int(os.environ.get(\"SQL_DRIFT_GRPO_SEED\", \"7\"))\n",
138
+ "OUTPUT_DIR = \"outputs/checkpoints/grpo_hackathon\"\n",
139
+ "EVIDENCE_DIR = Path(\"training/evidence\")\n",
140
+ "EVIDENCE_DIR.mkdir(parents=True, exist_ok=True)\n",
141
+ "\n",
142
+ "# P0 audit fixes (design/RL_HACKATHON_AUDIT_AND_RUNBOOK.md \u00a72.6):\n",
143
+ "# - max_seq_length 4096 \u2192 2048 (T4 16 GB OOM safety)\n",
144
+ "# - max_completion_length 1024 \u2192 256 (JSON tool envelopes need ~200 tok)\n",
145
+ "# - max_prompt_length kept large at 1536 (system prompt + tool history)\n",
146
+ "# - save_steps max(20, MAX_STEPS // 4) so a step-79 crash still leaves a checkpoint\n",
147
+ "# - default model swap to unsloth/Qwen3-1.7B-unsloth-bnb-4bit per audit \u00a74\n",
148
+ "cfg = GRPOConfig(\n",
149
+ " model_name=os.environ.get(\"SQL_DRIFT_MODEL_NAME\", \"unsloth/Qwen3-1.7B-unsloth-bnb-4bit\"),\n",
150
+ " env_base_url=ENV_BASE_URL,\n",
151
+ " output_dir=OUTPUT_DIR,\n",
152
+ " max_steps=MAX_STEPS,\n",
153
+ " group_size=GROUP_SIZE,\n",
154
+ " gradient_accumulation_steps=int(os.environ.get(\"SQL_DRIFT_GRAD_ACCUM\", \"1\")),\n",
155
+ " learning_rate=float(os.environ.get(\"SQL_DRIFT_LR\", \"5e-6\")),\n",
156
+ " warmup_steps=min(10, max(1, MAX_STEPS // 10)),\n",
157
+ " save_steps=max(20, MAX_STEPS // 4),\n",
158
+ " logging_steps=1,\n",
159
+ " seed=SEED,\n",
160
+ " max_seq_length=2048,\n",
161
+ " max_prompt_length=1536,\n",
162
+ " max_completion_length=256,\n",
163
+ " curriculum=CurriculumConfig(scenarios=ALL_SCENARIOS, mode=\"weighted\", weights=(1, 1, 1, 1, 1, 1, 2, 2, 2, 2)),\n",
164
+ ")\n",
165
+ "\n",
166
+ "(EVIDENCE_DIR / \"grpo_config.json\").write_text(json.dumps(asdict(cfg), indent=2))\n",
167
+ "print(json.dumps(asdict(cfg), indent=2))"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "markdown",
172
+ "metadata": {},
173
+ "source": [
174
+ "## 4. Run GRPO Training\n",
175
+ "\n",
176
+ "This calls `training.grpo_train.train(cfg)`, which loads the Unsloth model, attaches LoRA adapters, builds the curriculum dataset, and lets TRL drive multi-turn OpenEnv rollouts through `SqlDriftToolEnv`.\n"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "markdown",
181
+ "metadata": {},
182
+ "source": [
183
+ "### 4a. Free CPU smoke (max_steps=2)\n",
184
+ "\n",
185
+ "Pre-flight: validate env URL, dataset build, callback registration, and log file plumbing\n",
186
+ "**before** spending GPU time. P0 fix per `design/RL_HACKATHON_AUDIT_AND_RUNBOOK.md` \u00a75.3.\n",
187
+ "The assertion at the end aborts the notebook if the smoke produced no records."
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "metadata": {},
193
+ "execution_count": null,
194
+ "outputs": [],
195
+ "source": [
196
+ "from dataclasses import replace\n",
197
+ "\n",
198
+ "smoke_cfg = replace(\n",
199
+ " cfg,\n",
200
+ " max_steps=2,\n",
201
+ " save_steps=2,\n",
202
+ " output_dir=str(EVIDENCE_DIR / \"smoke\"),\n",
203
+ ")\n",
204
+ "print(f\"\\n=== SMOKE TEST: {smoke_cfg.max_steps} steps, output={smoke_cfg.output_dir} ===\\n\")\n",
205
+ "\n",
206
+ "from training.grpo_train import train as _train\n",
207
+ "\n",
208
+ "_smoke_trainer = _train(smoke_cfg)\n",
209
+ "_smoke_history = list(getattr(_smoke_trainer.state, \"log_history\", []))\n",
210
+ "assert _smoke_history, \"Smoke produced no log records \u2014 abort before the paid run.\"\n",
211
+ "print(f\"Smoke OK \u2014 {len(_smoke_history)} log records. Proceed to the paid run below.\")"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": null,
217
+ "metadata": {},
218
+ "outputs": [],
219
+ "source": [
220
+ "from training.grpo_train import train\n",
221
+ "\n",
222
+ "trainer = train(cfg)\n",
223
+ "history = list(getattr(trainer.state, \"log_history\", []))\n",
224
+ "print(f\"Captured {len(history)} trainer log records\")\n",
225
+ "print(f\"Checkpoint directory: {OUTPUT_DIR}\")"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "markdown",
230
+ "metadata": {},
231
+ "source": [
232
+ "## 5. Save Evidence Curves\n",
233
+ "\n",
234
+ "The evaluator requires committed image files, not only WandB links or inline notebook plots. This cell writes PNGs and the raw metrics CSV into `training/evidence/`.\n"
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "code",
239
+ "execution_count": null,
240
+ "metadata": {},
241
+ "outputs": [],
242
+ "source": [
243
+ "import pandas as pd\n",
244
+ "import matplotlib.pyplot as plt\n",
245
+ "\n",
246
+ "df = pd.DataFrame(history)\n",
247
+ "if df.empty:\n",
248
+ " raise RuntimeError(\"Trainer log history is empty; training did not emit metrics.\")\n",
249
+ "\n",
250
+ "if \"step\" not in df.columns:\n",
251
+ " df[\"step\"] = range(len(df))\n",
252
+ "\n",
253
+ "metrics_csv = EVIDENCE_DIR / \"grpo_metrics.csv\"\n",
254
+ "df.to_csv(metrics_csv, index=False)\n",
255
+ "\n",
256
+ "\n",
257
+ "def _numeric_columns(frame: pd.DataFrame) -> dict[str, pd.Series]:\n",
258
+ " out: dict[str, pd.Series] = {}\n",
259
+ " for column in frame.columns:\n",
260
+ " series = pd.to_numeric(frame[column], errors=\"coerce\")\n",
261
+ " if series.notna().any():\n",
262
+ " out[column] = series\n",
263
+ " return out\n",
264
+ "\n",
265
+ "\n",
266
+ "def _choose_metric(\n",
267
+ " frame: pd.DataFrame,\n",
268
+ " *,\n",
269
+ " preferred: list[str],\n",
270
+ " required_token: str,\n",
271
+ " exclude_tokens: tuple[str, ...] = (),\n",
272
+ ") -> tuple[str, pd.Series]:\n",
273
+ " numeric = _numeric_columns(frame)\n",
274
+ " for column in preferred:\n",
275
+ " if column in numeric:\n",
276
+ " return column, numeric[column]\n",
277
+ " for column, series in numeric.items():\n",
278
+ " lowered = column.lower()\n",
279
+ " if required_token in lowered and not any(token in lowered for token in exclude_tokens):\n",
280
+ " return column, series\n",
281
+ " available = \", \".join(sorted(numeric))\n",
282
+ " raise RuntimeError(\n",
283
+ " f\"Could not find a numeric {required_token!r} metric. Available: {available}\"\n",
284
+ " )\n",
285
+ "\n",
286
+ "\n",
287
+ "def _plot_curve(column: str, values: pd.Series, ylabel: str, title: str, path: Path) -> None:\n",
288
+ " plot_df = pd.DataFrame(\n",
289
+ " {\n",
290
+ " \"step\": pd.to_numeric(df[\"step\"], errors=\"coerce\"),\n",
291
+ " \"value\": values,\n",
292
+ " }\n",
293
+ " ).dropna()\n",
294
+ " if plot_df.empty:\n",
295
+ " raise RuntimeError(f\"Metric {column!r} has no plottable values.\")\n",
296
+ "\n",
297
+ " plt.figure(figsize=(8, 4.5))\n",
298
+ " plt.plot(plot_df[\"step\"], plot_df[\"value\"], marker=\"o\", linewidth=1.4, label=column)\n",
299
+ " if len(plot_df) >= 5:\n",
300
+ " window = min(10, max(2, len(plot_df) // 5))\n",
301
+ " rolling = plot_df[\"value\"].rolling(window=window, min_periods=1).mean()\n",
302
+ " plt.plot(plot_df[\"step\"], rolling, linewidth=2.2, label=f\"{window}-point rolling mean\")\n",
303
+ " plt.xlabel(\"GRPO step\")\n",
304
+ " plt.ylabel(ylabel)\n",
305
+ " plt.title(title)\n",
306
+ " plt.grid(alpha=0.25)\n",
307
+ " plt.legend()\n",
308
+ " plt.tight_layout()\n",
309
+ " plt.savefig(path, dpi=180, bbox_inches=\"tight\")\n",
310
+ " plt.show()\n",
311
+ "\n",
312
+ "\n",
313
+ "loss_col, loss_values = _choose_metric(\n",
314
+ " df,\n",
315
+ " preferred=[\"loss\", \"train_loss\"],\n",
316
+ " required_token=\"loss\",\n",
317
+ ")\n",
318
+ "reward_col, reward_values = _choose_metric(\n",
319
+ " df,\n",
320
+ " preferred=[\"reward\", \"rewards/mean\", \"mean_reward\", \"reward_mean\", \"train/reward\"],\n",
321
+ " required_token=\"reward\",\n",
322
+ " exclude_tokens=(\"std\", \"variance\"),\n",
323
+ ")\n",
324
+ "\n",
325
+ "loss_png = EVIDENCE_DIR / \"grpo_loss_curve.png\"\n",
326
+ "reward_png = EVIDENCE_DIR / \"grpo_reward_curve.png\"\n",
327
+ "_plot_curve(loss_col, loss_values, \"loss\", \"SQLDrift GRPO Loss Curve\", loss_png)\n",
328
+ "_plot_curve(reward_col, reward_values, \"episode reward\", \"SQLDrift GRPO Reward Curve\", reward_png)\n",
329
+ "\n",
330
+ "print(f\"Wrote metrics: {metrics_csv}\")\n",
331
+ "print(f\"Wrote loss curve: {loss_png}\")\n",
332
+ "print(f\"Wrote reward curve: {reward_png}\")"
333
+ ]
334
+ },
335
+ {
336
+ "cell_type": "markdown",
337
+ "metadata": {},
338
+ "source": [
339
+ "## 6. Commit The Evidence\n",
340
+ "\n",
341
+ "Once training completes, commit the notebook plus the generated evidence images:\n",
342
+ "\n",
343
+ "```bash\n",
344
+ "git add training/sql_drift_grpo_training.ipynb training/evidence/grpo_loss_curve.png training/evidence/grpo_reward_curve.png training/evidence/grpo_metrics.csv training/evidence/grpo_config.json\n",
345
+ "git commit -m \"Add SQLDrift GRPO training notebook and evidence\"\n",
346
+ "```\n",
347
+ "\n",
348
+ "Do not rely on WandB-only links or plots that live only in notebook output cells; the PNG files above are the durable artifacts evaluators can validate.\n"
349
+ ]
350
+ }
351
+ ],
352
+ "metadata": {
353
+ "kernelspec": {
354
+ "display_name": "Python 3",
355
+ "language": "python",
356
+ "name": "python3"
357
+ },
358
+ "language_info": {
359
+ "codemirror_mode": {
360
+ "name": "ipython",
361
+ "version": 3
362
+ },
363
+ "file_extension": ".py",
364
+ "mimetype": "text/x-python",
365
+ "name": "python",
366
+ "nbconvert_exporter": "python",
367
+ "pygments_lexer": "ipython3",
368
+ "version": "3.12"
369
+ }
370
+ },
371
+ "nbformat": 4,
372
+ "nbformat_minor": 5
373
+ }
training/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TRL/Unsloth GRPO training harness — P12/P13.
2
+
3
+ Public surface:
4
+
5
+ * :class:`training.config.GRPOConfig` / :class:`training.config.CurriculumConfig`
6
+ * :func:`training.prompt.render_system_prompt`
7
+ * :class:`training.random_agent.RandomAgent`
8
+ * :func:`training.grpo_train.train` (requires GPU + ``[train]`` extra)
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from training.config import ALL_SCENARIOS, CurriculumConfig, GRPOConfig
14
+ from training.prompt import (
15
+ render_prompt_from_observation,
16
+ render_system_prompt,
17
+ )
18
+ from training.random_agent import RandomAgent
19
+
20
+ __all__ = [
21
+ "ALL_SCENARIOS",
22
+ "CurriculumConfig",
23
+ "GRPOConfig",
24
+ "RandomAgent",
25
+ "render_prompt_from_observation",
26
+ "render_system_prompt",
27
+ ]
training/config.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training configuration dataclasses.
2
+
3
+ Holds every knob the :mod:`training.grpo_train` script or the eval CLI
4
+ needs, as plain, frozen dataclasses so they serialize cleanly to JSON
5
+ for experiment manifests.
6
+
7
+ Deliberately lightweight: do not import ``trl`` / ``unsloth`` /
8
+ ``transformers`` at module import time. Those libraries are CUDA-heavy
9
+ and optional. ``grpo_train.py`` resolves them lazily.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from dataclasses import dataclass, field
15
+ from typing import Literal
16
+
17
+ from utilities.env_loader import env_str
18
+
19
+
20
+ def _load_all_scenarios() -> tuple[str, ...]:
21
+ from scenarios import iter_specs
22
+
23
+ return tuple(spec.scenario_id for spec in iter_specs())
24
+
25
+
26
+ # Derived from the live registry so training defaults stay in sync with
27
+ # the scenarios actually shipped under ``scenarios/``.
28
+ ALL_SCENARIOS: tuple[str, ...] = _load_all_scenarios()
29
+
30
+
31
+ @dataclass(frozen=True)
32
+ class CurriculumConfig:
33
+ """Scenario sampling policy for GRPO rollouts.
34
+
35
+ ``mode="uniform"`` samples each id in :attr:`scenarios` with equal
36
+ probability. ``mode="weighted"`` uses :attr:`weights` (must be the
37
+ same length as :attr:`scenarios`) — useful for over-sampling drift
38
+ scenarios early in training. ``mode="static_order"`` iterates the
39
+ list round-robin (handy for reproducing eval-style runs).
40
+ """
41
+
42
+ scenarios: tuple[str, ...] = ALL_SCENARIOS
43
+ mode: Literal["uniform", "weighted", "static_order"] = "uniform"
44
+ weights: tuple[float, ...] | None = None
45
+ seed_range: tuple[int, int] = (0, 2**31 - 1)
46
+
47
+ def __post_init__(self) -> None:
48
+ if not self.scenarios:
49
+ raise ValueError("CurriculumConfig.scenarios must be non-empty")
50
+ if self.mode == "weighted":
51
+ if self.weights is None or len(self.weights) != len(self.scenarios):
52
+ raise ValueError("mode='weighted' requires weights of the same length as scenarios")
53
+ if any(w < 0 for w in self.weights):
54
+ raise ValueError("weights must all be >= 0")
55
+ if sum(self.weights) <= 0:
56
+ raise ValueError("at least one weight must be > 0")
57
+ lo, hi = self.seed_range
58
+ if lo < 0 or hi <= lo:
59
+ raise ValueError("seed_range must be (lo >= 0, hi > lo)")
60
+
61
+
62
+ @dataclass(frozen=True)
63
+ class GRPOConfig:
64
+ """Top-level training config for the GRPO skeleton.
65
+
66
+ Defaults: Qwen3-1.7B (Unsloth bnb-4bit) + LoRA r=16, group size 4,
67
+ 500 max steps — but every knob is override-able from the CLI or a
68
+ JSON/YAML manifest. Model swap from Qwen2.5-1.5B → Qwen3-1.7B per
69
+ design/RL_HACKATHON_AUDIT_AND_RUNBOOK.md (section 4: newer family with
70
+ documented agentic tool-calling, comparable footprint).
71
+ """
72
+
73
+ model_name: str = "unsloth/Qwen3-1.7B-unsloth-bnb-4bit"
74
+ max_seq_length: int = 4096
75
+ load_in_4bit: bool = True
76
+
77
+ lora_r: int = 16
78
+ lora_alpha: int = 16
79
+ lora_dropout: float = 0.0
80
+ lora_target_modules: tuple[str, ...] = (
81
+ "q_proj",
82
+ "k_proj",
83
+ "v_proj",
84
+ "o_proj",
85
+ "gate_proj",
86
+ "up_proj",
87
+ "down_proj",
88
+ )
89
+
90
+ # GRPO knobs
91
+ group_size: int = 4
92
+ learning_rate: float = 5e-6
93
+ max_steps: int = 500
94
+ gradient_accumulation_steps: int = 1
95
+ warmup_steps: int = 10
96
+ max_prompt_length: int = 2048
97
+ max_completion_length: int = 1024
98
+ temperature: float = 0.8
99
+ top_p: float = 0.95
100
+ seed: int = 0
101
+
102
+ # Env wiring
103
+ env_base_url: str = env_str("SQL_DRIFT_ENV_URL", "http://localhost:8000")
104
+ episode_step_budget: int = 25
105
+ dba_oracle_enabled: bool = False
106
+
107
+ # IO
108
+ output_dir: str = "outputs/grpo_run"
109
+ logging_steps: int = 1
110
+ save_steps: int = 100
111
+
112
+ curriculum: CurriculumConfig = field(default_factory=CurriculumConfig)
113
+
114
+ def __post_init__(self) -> None:
115
+ if self.group_size < 2:
116
+ raise ValueError("GRPO group_size must be >= 2 (GRPO requires groups).")
117
+ if self.max_steps < 1:
118
+ raise ValueError("max_steps must be >= 1")
119
+ if self.seed < 0:
120
+ raise ValueError("seed must be >= 0")
121
+ if self.lora_r < 1:
122
+ raise ValueError("lora_r must be >= 1")
123
+ if not 0.0 < self.temperature <= 2.0:
124
+ raise ValueError("temperature must be in (0, 2]")
125
+
126
+
127
+ __all__ = ["ALL_SCENARIOS", "CurriculumConfig", "GRPOConfig"]
training/eval.py ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Evaluation harness for SQLDrift.
2
+
3
+ K-rollouts-per-scenario sweep that emits a markdown ``report.md`` and a
4
+ row-level ``per_episode.csv`` so reviewers can eyeball pass rate and
5
+ reward distribution without spinning up a notebook.
6
+
7
+ Invocation::
8
+
9
+ python -m training.eval \
10
+ --checkpoint base \
11
+ --scenarios 1-10 \
12
+ --seeds-per-scenario 5 \
13
+ --out outputs/evals/<run_id>/
14
+
15
+ ``--checkpoint base`` runs the CPU :class:`RandomAgent` as a baseline
16
+ (no LLM, no GPU). A non-"base" value is a pointer to a saved adapter
17
+ and will attempt the lazy Unsloth import path in :func:`load_agent`.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import argparse
23
+ import csv
24
+ import json
25
+ import math
26
+ import statistics
27
+ import time
28
+ from collections.abc import Callable
29
+ from dataclasses import dataclass, field
30
+ from pathlib import Path
31
+ from typing import Any, Protocol
32
+
33
+ from models import REWARD_COMPONENT_KEYS, SqlDriftAction, SqlDriftObservation
34
+ from server import SqlDriftEnvironment
35
+ from training.config import ALL_SCENARIOS
36
+ from training.random_agent import RandomAgent
37
+ from training.seeding import set_seed
38
+ from utilities.logger import get_module_logger
39
+
40
+ _LOG = get_module_logger(__name__)
41
+
42
+ PASS_REWARD_THRESHOLD = 0.5
43
+ """An episode is a "pass" when its terminal reward (the step that flips
44
+ ``obs.done``—the submit step or the last step before budget exhaustion)
45
+ meets this threshold. The rubric scores a correct submission at
46
+ ``+1.0`` in :attr:`r_correct`, so ``0.5`` matches the
47
+ "correct-but-not-improved" knee. Per-step values accumulate in
48
+ ``episode_return`` and are not used for the pass bit."""
49
+
50
+
51
+ # -----------------------------------------------------------------------------
52
+ # Agent interface
53
+ # -----------------------------------------------------------------------------
54
+
55
+
56
+ class Agent(Protocol):
57
+ """Duck-typed policy — :meth:`RandomAgent.act` fits this shape."""
58
+
59
+ def reset(self, seed: int | None = None, scenario_id: str | None = None) -> None: ...
60
+ def act(self, obs: SqlDriftObservation) -> SqlDriftAction: ...
61
+
62
+
63
+ def load_agent(
64
+ checkpoint: str,
65
+ *,
66
+ seed: int = 0,
67
+ base_model: str | None = None,
68
+ temperature: float = 0.0,
69
+ ) -> Agent:
70
+ """Resolve a checkpoint spec to a concrete agent.
71
+
72
+ * ``base``/``random`` — CPU-only :class:`RandomAgent` baseline.
73
+ * Any other value is treated as a filesystem path (a full
74
+ Hugging Face checkpoint directory or a PEFT adapter directory). The
75
+ :class:`training.llm_agent.LLMAgent` is imported lazily so
76
+ CPU-only CI that never calls ``load_agent`` with a path never
77
+ has to install ``transformers``/``peft``.
78
+
79
+ ``base_model`` is forwarded to :class:`LLMAgent` when the adapter
80
+ directory does not pin its base model; ``temperature=0`` (greedy)
81
+ is the default for deterministic eval sweeps.
82
+ """
83
+ if checkpoint in ("base", "random"):
84
+ return RandomAgent(seed=seed)
85
+
86
+ path = Path(checkpoint)
87
+ if not path.exists():
88
+ raise FileNotFoundError(
89
+ f"checkpoint {checkpoint!r} is not 'base'/'random' and does not exist on disk"
90
+ )
91
+
92
+ # Lazy import: keeps the heavy transformers/peft tree out of
93
+ # the default import surface for ``training.eval``.
94
+ from training.llm_agent import LLMAgent
95
+
96
+ return LLMAgent(
97
+ str(path),
98
+ base_model=base_model,
99
+ temperature=temperature,
100
+ seed=seed,
101
+ )
102
+
103
+
104
+ # -----------------------------------------------------------------------------
105
+ # Episode runner
106
+ # -----------------------------------------------------------------------------
107
+
108
+
109
+ @dataclass
110
+ class EpisodeResult:
111
+ scenario_id: str
112
+ seed: int
113
+ terminal_reward: float
114
+ episode_return: float
115
+ steps: int
116
+ passed: bool
117
+ submitted: bool
118
+ drift_fired: bool
119
+ wall_ms: float
120
+ reward_components: dict[str, float] = field(default_factory=dict)
121
+ effective_speedup: float | None = None
122
+
123
+ def as_row(self) -> dict[str, str]:
124
+ row: dict[str, str] = {
125
+ "scenario_id": self.scenario_id,
126
+ "seed": str(self.seed),
127
+ "terminal_reward": f"{self.terminal_reward:.4f}",
128
+ "episode_return": f"{self.episode_return:.4f}",
129
+ "steps": str(self.steps),
130
+ "passed": "1" if self.passed else "0",
131
+ "submitted": "1" if self.submitted else "0",
132
+ "drift_fired": "1" if self.drift_fired else "0",
133
+ "wall_ms": f"{self.wall_ms:.2f}",
134
+ "effective_speedup": _format_speedup(self.effective_speedup),
135
+ }
136
+ for k in REWARD_COMPONENT_KEYS:
137
+ row[k] = f"{self.reward_components.get(k, 0.0):.4f}"
138
+ return row
139
+
140
+
141
+ def _format_speedup(value: float | None) -> str:
142
+ """Render an effective_speedup cell for CSV output.
143
+
144
+ ``None`` (no submission) and ``+∞`` (baseline invalidated by drift)
145
+ need distinct, non-numeric representations so a downstream parser
146
+ cannot conflate "no data" with "infinite" — both read back as empty
147
+ cells today, which would bias per-scenario means.
148
+ """
149
+ if value is None:
150
+ return ""
151
+ if math.isinf(value):
152
+ return "inf"
153
+ return f"{value:.3f}"
154
+
155
+
156
+ def _effective_speedup(env: SqlDriftEnvironment) -> float | None:
157
+ """Read the current episode's effective speedup through the env surface.
158
+
159
+ Kept as a thin shim so tests and reporting code don't need to reach
160
+ into env internals themselves.
161
+ """
162
+ return env.effective_speedup()
163
+
164
+
165
+ def _run_one_episode(
166
+ env: SqlDriftEnvironment,
167
+ agent: Agent,
168
+ *,
169
+ scenario_id: str,
170
+ seed: int,
171
+ max_steps: int = 30,
172
+ ) -> EpisodeResult:
173
+ t0 = time.perf_counter()
174
+ obs = env.reset(seed=seed, scenario_id=scenario_id)
175
+ agent.reset(seed=seed, scenario_id=scenario_id)
176
+
177
+ episode_return = 0.0
178
+ terminal_reward = 0.0
179
+ last_components: dict[str, float] = {}
180
+ steps = 0
181
+ while not obs.done and steps < max_steps:
182
+ action = agent.act(obs)
183
+ obs = env.step(action)
184
+ if obs.reward is not None:
185
+ episode_return += obs.reward
186
+ terminal_reward = obs.reward
187
+ if obs.reward_components:
188
+ last_components = dict(obs.reward_components)
189
+ steps += 1
190
+
191
+ state = env.state
192
+ wall_ms = (time.perf_counter() - t0) * 1000.0
193
+
194
+ return EpisodeResult(
195
+ scenario_id=scenario_id,
196
+ seed=seed,
197
+ terminal_reward=terminal_reward,
198
+ episode_return=episode_return,
199
+ steps=steps,
200
+ passed=terminal_reward >= PASS_REWARD_THRESHOLD,
201
+ submitted=state.submitted,
202
+ drift_fired=state.drift_fired,
203
+ wall_ms=wall_ms,
204
+ reward_components=last_components,
205
+ effective_speedup=_effective_speedup(env),
206
+ )
207
+
208
+
209
+ def run_eval(
210
+ *,
211
+ checkpoint: str,
212
+ scenarios: list[str],
213
+ seeds_per_scenario: int,
214
+ out_dir: Path,
215
+ max_steps: int = 30,
216
+ base_seed: int = 0,
217
+ progress_cb: Callable[[int, int], None] | None = None,
218
+ base_model: str | None = None,
219
+ temperature: float = 0.0,
220
+ ) -> dict[str, Any]:
221
+ """Drive the full sweep. Returns the summary dict also written to JSON.
222
+
223
+ ``base_model`` / ``temperature`` are forwarded to
224
+ :func:`load_agent` for LLM-checkpoint runs (ignored for
225
+ ``base``/``random``).
226
+ """
227
+ out_dir.mkdir(parents=True, exist_ok=True)
228
+ set_seed(base_seed)
229
+ agent = load_agent(
230
+ checkpoint,
231
+ seed=base_seed,
232
+ base_model=base_model,
233
+ temperature=temperature,
234
+ )
235
+
236
+ results: list[EpisodeResult] = []
237
+ env = SqlDriftEnvironment()
238
+ total = len(scenarios) * seeds_per_scenario
239
+ done = 0
240
+ try:
241
+ for scenario_id in scenarios:
242
+ for k in range(seeds_per_scenario):
243
+ seed = base_seed + k
244
+ res = _run_one_episode(
245
+ env,
246
+ agent,
247
+ scenario_id=scenario_id,
248
+ seed=seed,
249
+ max_steps=max_steps,
250
+ )
251
+ results.append(res)
252
+ done += 1
253
+ if progress_cb:
254
+ progress_cb(done, total)
255
+ finally:
256
+ env.close()
257
+
258
+ _write_per_episode_csv(results, out_dir / "per_episode.csv")
259
+ summary = _build_summary(
260
+ results,
261
+ checkpoint=checkpoint,
262
+ scenarios=scenarios,
263
+ seeds_per_scenario=seeds_per_scenario,
264
+ )
265
+ (out_dir / "summary.json").write_text(json.dumps(summary, indent=2))
266
+ (out_dir / "report.md").write_text(render_report(summary, results))
267
+ return summary
268
+
269
+
270
+ # -----------------------------------------------------------------------------
271
+ # Reporting
272
+ # -----------------------------------------------------------------------------
273
+
274
+
275
+ def _write_per_episode_csv(results: list[EpisodeResult], path: Path) -> None:
276
+ if not results:
277
+ path.write_text("")
278
+ return
279
+ fieldnames = list(results[0].as_row().keys())
280
+ with path.open("w", newline="") as f:
281
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
282
+ writer.writeheader()
283
+ for r in results:
284
+ writer.writerow(r.as_row())
285
+
286
+
287
+ def _speedup_stats(
288
+ speedups: list[float],
289
+ ) -> tuple[float | None, int]:
290
+ """Finite-mean + ``+∞`` count for a list of effective_speedup values.
291
+
292
+ ``statistics.fmean`` on a list containing ``+∞`` returns ``+∞``,
293
+ which poisons JSON output (``json.dumps`` rejects ``Infinity`` by
294
+ default) and washes out the finite distribution. We split the two
295
+ channels: the reported mean is over *finite* speedups only, and the
296
+ ``infinite_count`` is surfaced separately so reviewers can see how
297
+ often drift invalidated the baseline.
298
+ """
299
+ finite = [s for s in speedups if not math.isinf(s)]
300
+ infinite_count = len(speedups) - len(finite)
301
+ mean = statistics.fmean(finite) if finite else None
302
+ return mean, infinite_count
303
+
304
+
305
+ def _build_summary(
306
+ results: list[EpisodeResult],
307
+ *,
308
+ checkpoint: str,
309
+ scenarios: list[str],
310
+ seeds_per_scenario: int,
311
+ ) -> dict[str, Any]:
312
+ by_scenario: dict[str, dict[str, Any]] = {}
313
+ for sid in scenarios:
314
+ scoped = [r for r in results if r.scenario_id == sid]
315
+ if not scoped:
316
+ continue
317
+ terminals = [r.terminal_reward for r in scoped]
318
+ returns = [r.episode_return for r in scoped]
319
+ speedups = [r.effective_speedup for r in scoped if r.effective_speedup is not None]
320
+ mean_sp, inf_sp = _speedup_stats(speedups)
321
+ by_scenario[sid] = {
322
+ "n": len(scoped),
323
+ "pass_rate": sum(1 for r in scoped if r.passed) / len(scoped),
324
+ "mean_terminal_reward": statistics.fmean(terminals),
325
+ "std_terminal_reward": statistics.pstdev(terminals) if len(terminals) > 1 else 0.0,
326
+ "mean_episode_return": statistics.fmean(returns),
327
+ "submit_rate": sum(1 for r in scoped if r.submitted) / len(scoped),
328
+ "mean_effective_speedup": mean_sp,
329
+ "infinite_speedup_count": inf_sp,
330
+ }
331
+
332
+ all_speedups = [r.effective_speedup for r in results if r.effective_speedup is not None]
333
+ mean_sp_all, inf_sp_all = _speedup_stats(all_speedups)
334
+ overall = {
335
+ "checkpoint": checkpoint,
336
+ "n_episodes": len(results),
337
+ "seeds_per_scenario": seeds_per_scenario,
338
+ "pass_rate": sum(1 for r in results if r.passed) / len(results) if results else 0.0,
339
+ "mean_terminal_reward": (
340
+ statistics.fmean(r.terminal_reward for r in results) if results else 0.0
341
+ ),
342
+ "std_terminal_reward": (
343
+ statistics.pstdev([r.terminal_reward for r in results]) if len(results) > 1 else 0.0
344
+ ),
345
+ "mean_episode_return": (
346
+ statistics.fmean(r.episode_return for r in results) if results else 0.0
347
+ ),
348
+ "submit_rate": sum(1 for r in results if r.submitted) / len(results) if results else 0.0,
349
+ "mean_effective_speedup": mean_sp_all,
350
+ "infinite_speedup_count": inf_sp_all,
351
+ }
352
+ return {"overall": overall, "by_scenario": by_scenario}
353
+
354
+
355
+ def render_report(
356
+ summary: dict[str, Any],
357
+ results: list[EpisodeResult],
358
+ ) -> str:
359
+ """Compose a reviewer-friendly ``report.md`` string."""
360
+ overall = summary["overall"]
361
+ lines: list[str] = []
362
+ lines.append("# SQLDrift evaluation report")
363
+ lines.append("")
364
+ speedup = overall.get("mean_effective_speedup")
365
+ speedup_cell = f"{speedup:.2f}x" if speedup is not None else "—"
366
+ lines.append(f"- Checkpoint: `{overall['checkpoint']}`")
367
+ lines.append(f"- Episodes: **{overall['n_episodes']}**")
368
+ lines.append(f"- Seeds/scenario: {overall['seeds_per_scenario']}")
369
+ lines.append(
370
+ f"- Overall pass rate (terminal reward ≥ {PASS_REWARD_THRESHOLD}): "
371
+ f"**{overall['pass_rate']:.1%}**"
372
+ )
373
+ lines.append(
374
+ f"- Mean terminal reward: **{overall['mean_terminal_reward']:.3f}** "
375
+ f"(σ = {overall['std_terminal_reward']:.3f})"
376
+ )
377
+ lines.append(f"- Mean episode return: {overall['mean_episode_return']:.3f}")
378
+ lines.append(f"- Submit rate: {overall['submit_rate']:.1%}")
379
+ lines.append(f"- Mean effective speedup (finite, submitted episodes): **{speedup_cell}**")
380
+ inf_count = overall.get("infinite_speedup_count", 0)
381
+ if inf_count:
382
+ lines.append(f"- Infinite-speedup episodes (drift invalidated baseline): **{inf_count}**")
383
+ lines.append("")
384
+
385
+ lines.append("## Per-scenario")
386
+ lines.append("")
387
+ lines.append("| Scenario | N | Pass | Terminal μ | Return μ | Submit | Speedup |")
388
+ lines.append("|----------|---|------|-----------|----------|--------|---------|")
389
+ for sid, row in summary["by_scenario"].items():
390
+ sp = row.get("mean_effective_speedup")
391
+ sp_cell = f"{sp:.2f}x" if sp is not None else "—"
392
+ lines.append(
393
+ f"| `{sid}` | {row['n']} | {row['pass_rate']:.0%} | "
394
+ f"{row['mean_terminal_reward']:.3f} | {row['mean_episode_return']:.3f} | "
395
+ f"{row['submit_rate']:.0%} | {sp_cell} |"
396
+ )
397
+ lines.append("")
398
+
399
+ lines.append("## Reward-component bars (mean across episodes)")
400
+ lines.append("")
401
+ lines.append(_render_component_bars(results))
402
+ lines.append("")
403
+
404
+ lines.append("## Notes")
405
+ lines.append("")
406
+ lines.append(
407
+ "- Random baseline does NOT attempt rewrites intelligently; "
408
+ "nonzero pass rate here is a lower bound on a well-trained agent."
409
+ )
410
+ lines.append(
411
+ "- Pass threshold is set at reward ≥ "
412
+ f"{PASS_REWARD_THRESHOLD}, matching the rubric's "
413
+ '"correct-but-not-improved" +0.5 partial credit (rubric).'
414
+ )
415
+ return "\n".join(lines) + "\n"
416
+
417
+
418
+ def _render_component_bars(results: list[EpisodeResult]) -> str:
419
+ if not results:
420
+ return "_no data_"
421
+ sums: dict[str, float] = {k: 0.0 for k in REWARD_COMPONENT_KEYS}
422
+ counts: dict[str, int] = {k: 0 for k in REWARD_COMPONENT_KEYS}
423
+ for r in results:
424
+ for k, v in r.reward_components.items():
425
+ if k in sums:
426
+ sums[k] += v
427
+ counts[k] += 1
428
+ means = {k: (sums[k] / counts[k] if counts[k] else 0.0) for k in sums}
429
+
430
+ # ASCII bar with sign (+/-).
431
+ max_abs = max((abs(v) for v in means.values()), default=1.0) or 1.0
432
+ width = 30
433
+ lines = ["```"]
434
+ for k in REWARD_COMPONENT_KEYS:
435
+ v = means[k]
436
+ bar_len = int(round(abs(v) / max_abs * width))
437
+ bar = ("█" * bar_len) if v >= 0 else ("▒" * bar_len)
438
+ lines.append(f"{k:<14} {v:+7.3f} {bar}")
439
+ lines.append("```")
440
+ return "\n".join(lines)
441
+
442
+
443
+ # -----------------------------------------------------------------------------
444
+ # CLI
445
+ # -----------------------------------------------------------------------------
446
+
447
+
448
+ def _expand_scenarios(spec: str) -> list[str]:
449
+ """Accept either ``1-10``, ``1,3,5``, or a comma-list of raw ids."""
450
+ if "-" in spec and all(part.isdigit() for part in spec.split("-")):
451
+ lo, hi = (int(x) for x in spec.split("-"))
452
+ want_range: set[int] = set(range(lo, hi + 1))
453
+ return [s for s in ALL_SCENARIOS if int(s.split("_", 1)[0]) in want_range]
454
+ if all(part.strip().isdigit() for part in spec.split(",")):
455
+ want_set = {int(part) for part in spec.split(",")}
456
+ return [s for s in ALL_SCENARIOS if int(s.split("_", 1)[0]) in want_set]
457
+ return [s.strip() for s in spec.split(",")]
458
+
459
+
460
+ def main(argv: list[str] | None = None) -> None:
461
+ ap = argparse.ArgumentParser(description="Evaluate an agent on SQLDrift.")
462
+ ap.add_argument("--checkpoint", required=True, help="'base' or adapter path")
463
+ ap.add_argument(
464
+ "--scenarios",
465
+ default="1-10",
466
+ help="Scenario range/list (e.g. '1-10', '1,3,5') or raw ids",
467
+ )
468
+ ap.add_argument("--seeds-per-scenario", type=int, default=5)
469
+ ap.add_argument("--out", type=Path, required=True)
470
+ ap.add_argument("--max-steps", type=int, default=30)
471
+ ap.add_argument("--base-seed", type=int, default=0)
472
+ ap.add_argument(
473
+ "--base-model",
474
+ default=None,
475
+ help="Override the base model id when --checkpoint points at a PEFT adapter",
476
+ )
477
+ ap.add_argument(
478
+ "--temperature",
479
+ type=float,
480
+ default=0.0,
481
+ help="Decoding temperature for LLM checkpoints (0 = greedy / deterministic).",
482
+ )
483
+ ns = ap.parse_args(argv)
484
+
485
+ scenarios = _expand_scenarios(ns.scenarios)
486
+ if not scenarios:
487
+ raise SystemExit(f"no scenarios matched spec {ns.scenarios!r}")
488
+
489
+ def _prog(done: int, total: int) -> None:
490
+ if done == total or done % max(1, total // 10) == 0:
491
+ _LOG.info("eval: %d/%d episodes", done, total)
492
+
493
+ summary = run_eval(
494
+ checkpoint=ns.checkpoint,
495
+ scenarios=scenarios,
496
+ seeds_per_scenario=ns.seeds_per_scenario,
497
+ out_dir=ns.out,
498
+ max_steps=ns.max_steps,
499
+ base_seed=ns.base_seed,
500
+ progress_cb=_prog,
501
+ base_model=ns.base_model,
502
+ temperature=ns.temperature,
503
+ )
504
+ print(json.dumps(summary["overall"], indent=2))
505
+
506
+
507
+ if __name__ == "__main__":
508
+ main()
509
+
510
+
511
+ __all__ = [
512
+ "Agent",
513
+ "EpisodeResult",
514
+ "PASS_REWARD_THRESHOLD",
515
+ "load_agent",
516
+ "main",
517
+ "render_report",
518
+ "run_eval",
519
+ ]
training/evidence/grpo_components_combined.png ADDED

Git LFS Details

  • SHA256: 48462be701ddc14eb7fd4933a9ad1c94c9b5a3692169e9d98dc56ab12a78baaa
  • Pointer size: 131 Bytes
  • Size of remote file: 232 kB
training/evidence/grpo_loss_curve.png ADDED
training/evidence/grpo_metrics.csv ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ step,loss,reward,r_correct,r_drift,r_step_tax
2
+ 1,1.5070857702340332,-1.089657810470115,0.0060142775960789485,-0.05147482162696701,-0.5153269778837614
3
+ 2,1.4163939988607464,-0.8820803512725144,0.03844503923645059,-0.10545186822296794,-0.4861222451750179
4
+ 3,1.4296624851767596,-0.9146981630547505,0.10331660081021651,0.07228358802913387,-0.4961777953169654
5
+ 4,1.3142921668473637,-0.5595880667701431,0.08344587035631848,0.059823292014677544,-0.520072970343028
6
+ 5,1.2701518625957149,-0.5996319117518063,0.2610718206538336,0.004943992480230808,-0.4840739746024923
7
+ 6,1.2640116521205307,-0.8258103297724467,0.26761100025304996,0.1913788646035109,-0.49710482987102533
8
+ 7,1.1209233460223358,-0.6211756254418923,0.27404383444577135,0.200216496595977,-0.4682072435811402
9
+ 8,1.1774670143175638,-0.619087388960126,0.270469688082585,0.17365149920126927,-0.4864816492894405
10
+ 9,1.1983245069119268,-0.5491623820497955,0.3851815028189252,0.14236176514396473,-0.47340387335667083
11
+ 10,1.1900131680220651,-0.4531450217747221,0.46035015896826725,0.20935614488999046,-0.3977678494548791
12
+ 11,1.0778750024182207,-0.5529804804536397,0.4784254548477095,0.25126495579288227,-0.4914229994698512
13
+ 12,1.0142047885469356,-0.30198009788159785,0.5200582520058386,0.24302822958435058,-0.46425248430094396
14
+ 13,1.0045646774653714,-0.4645827945838715,0.5862339331205229,0.27468662267016536,-0.4534963800694839
15
+ 14,0.9116167040003166,-0.3810888813065016,0.5319641343757862,0.2831807948097964,-0.43758017046734105
16
+ 15,0.9119484849618256,-0.19031109272031552,0.544487836273619,0.30242310302170133,-0.43358305018602905
17
+ 16,0.8631872796395779,-0.1902270628270999,0.5378009235565281,0.46520136753468744,-0.4378038803670464
18
+ 17,0.9108494430815297,-0.28522550174067374,0.6317914174443212,0.39742825226075135,-0.4194535022660978
19
+ 18,0.7414210838734385,-0.026510480876009573,0.6298982943138216,0.49275288649025134,-0.4263594599419731
20
+ 19,0.8025799827262751,-0.09087146161792733,0.69004717240886,0.4924557349195555,-0.3603965798497102
21
+ 20,0.6973776508874913,0.011148493528077183,0.7800312956412032,0.456173588246084,-0.4150265290706628
training/evidence/grpo_r_correct_curve.png ADDED
training/evidence/grpo_r_drift_curve.png ADDED