diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..238a5dff8ae8336669e310aad05572ba396936fb 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +training/evidence/grpo_components_combined.png filter=lfs diff=lfs merge=lfs -text +training/evidence/grpo_reward_curve.png filter=lfs diff=lfs merge=lfs -text diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..9a1fb2bee14ecdbdd063279081ee3c3b0a554a83 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Multi-stage build using openenv-base +# This Dockerfile is flexible and works for both: +# - In-repo environments (with local OpenEnv sources) +# - Standalone environments (with openenv from PyPI/Git) +# The build script (openenv build) handles context detection and sets appropriate build args. + +# ─── IMAGE PINNING ──────────────────────────────────────────────────────────── +# Using :latest is a moving target: a base-image update between local build +# and judging rebuild can pull a different Python / CUDA / OS version. +# +# TO PIN (run once, commit the digest): +# docker pull ghcr.io/meta-pytorch/openenv-base:latest +# docker inspect --format='{{index .RepoDigests 0}}' ghcr.io/meta-pytorch/openenv-base:latest +# # → ghcr.io/meta-pytorch/openenv-base@sha256: +# +# Then replace the ARG default below with the digest, e.g.: +# ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base@sha256: +# +# The sha256 digest is immutable; the image will not change between builds. +# ────────────────────────────────────────────────────────────────────────────── +ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest +FROM ${BASE_IMAGE} AS builder + +WORKDIR /app + +# Ensure git is available (required for installing dependencies from VCS) +RUN apt-get update && \ + apt-get install -y --no-install-recommends git && \ + rm -rf /var/lib/apt/lists/* + +# Build argument to control whether we're building standalone or in-repo +ARG BUILD_MODE=in-repo +ARG ENV_NAME=sql_drift_env + +# Copy environment code (always at root of build context) +COPY . /app/env + +# For in-repo builds, openenv is already vendored in the build context +# For standalone builds, openenv will be installed via pyproject.toml +WORKDIR /app/env + +# Ensure uv is available (for local builds where base image lacks it) +RUN if ! command -v uv >/dev/null 2>&1; then \ + curl -LsSf https://astral.sh/uv/install.sh | sh && \ + mv /root/.local/bin/uv /usr/local/bin/uv && \ + mv /root/.local/bin/uvx /usr/local/bin/uvx; \ + fi + +# Install dependencies from the frozen lock file. Failing fast when uv.lock +# is absent prevents a silent re-resolution that could pull different versions +# than the ones tested (see design/codereview.md, dependency pins). +RUN --mount=type=cache,target=/root/.cache/uv \ + uv sync --frozen --no-install-project --no-editable + +RUN --mount=type=cache,target=/root/.cache/uv \ + uv sync --frozen --no-editable + +# Final runtime stage — inherits the same pinned/latest image chosen above. +FROM ${BASE_IMAGE} + +WORKDIR /app + +# Copy the uv-managed Python interpreter from the builder so that the +# shebangs baked into ``/app/env/.venv/bin/*`` (e.g. uvicorn) resolve. +# Without this, execve() on those scripts silently falls back to the +# system ``/usr/local/bin/uvicorn`` (Python 3.11) and loads the wrong +# site-packages, yielding ``ModuleNotFoundError: duckdb`` at startup. +COPY --from=builder /root/.local/share/uv /root/.local/share/uv + +# Copy the environment code (includes the builder's .venv at /app/env/.venv) +COPY --from=builder /app/env /app/env + +# Point PATH + PYTHONPATH at the venv that shipped with the env so +# both ``uvicorn ...`` and ``python -m uvicorn ...`` resolve against +# the right interpreter. +ENV PATH="/app/env/.venv/bin:$PATH" +ENV PYTHONPATH="/app/env:$PYTHONPATH" + +# Health check (hits the stateless /health endpoint exposed by the +# OpenEnv FastAPI factory). +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 + +# Run the FastAPI server via the venv's Python interpreter directly +# (bypassing any PATH-lookup ambiguity with the base image's system +# uvicorn). +ENV ENABLE_WEB_INTERFACE=true +CMD ["/app/env/.venv/bin/python", "-m", "uvicorn", "server.app:app", \ + "--host", "0.0.0.0", "--port", "8000", "--app-dir", "/app/env"] diff --git a/README.md b/README.md index 53393322d258099e2e00c5b0e47b39942838713a..8d23197c0fae3d197668c1af035c1fc3eae63a1c 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,155 @@ --- -title: Sql Drift Env -emoji: 📚 -colorFrom: pink -colorTo: red +title: SQLDrift +emoji: 🐘 +colorFrom: blue +colorTo: indigo sdk: docker -pinned: false +app_port: 8000 +short_description: SQL repair and optimization under live schema drift +base_path: /web --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# SQLDrift + +> An OpenEnv gym that teaches an LLM agent to **repair and optimize SQL** when +> the database schema and business rules shift out from under it. + +SQLDrift is a production-grade [OpenEnv](https://github.com/meta-pytorch/OpenEnv) +environment designed for RL fine-tuning with TRL's +[`GRPOTrainer`](https://huggingface.co/docs/trl/openenv) (plus +[Unsloth](https://github.com/unslothai/unsloth) LoRA on consumer GPUs). +Each episode hands the agent a slow baseline SQL query, a lightly-populated +DuckDB database, a toolbox of read-only introspection / execution tools, and a +**25-step budget**. Midway through the episode a schema or business-rule +**drift event** may fire; the agent must read the deploy changelog, re-diagnose, +and submit a rewrite that is both (a) semantically correct and (b) ≥ 1.2× faster +than the baseline. A hierarchical rubric turns that into six reward signals +(correctness, drift-adapt, speedup, step-tax, gatekeepers, consult-dba). + +Highlights: + +- **10 hand-crafted scenarios** — 6 classic anti-patterns (correlated + subqueries, `SELECT *` joins, Cartesian joins, `DISTINCT` groupbys, nested + subqueries, `HAVING`-as-`WHERE`) and 4 drift events (column rename, date + format flip, enum rule split, field deprecation). +- **Deterministic fixtures** — 15–30-column schemas, 2–4 tables per scenario, + regenerated in-process at `reset()` from a seed. No Parquet files, no + pre-baked sqlite dumps. +- **Sqlglot-canonicalised baseline-verbatim gate** — agents that rename + whitespace and resubmit the baseline cannot farm the +1.0 correct bonus. +- **Skill library** — 12 pre-seeded playbook/drift-card entries plus an + on-disk, `fcntl.flock`-guarded JSON store that grows as the agent solves + episodes, surfaced next `reset()` via Jaccard retrieval. +- **Feature-flagged DBA Oracle** — 3-tier escalating hints per scenario, + penalised by the `ConsultDBA` rubric. Off by default. +- **Dockerised** — `server/Dockerfile` layers the env on + `ghcr.io/meta-pytorch/openenv-base` and exposes `/health`, `/reset`, + `/step`, `/ws` out-of-the-box. + +## Quick start + +```bash +git clone /sql_drift_env.git +cd sql_drift_env +uv sync # install deps +uv run pytest -q # 300+ tests, ~90s on CPU +uv run uvicorn server.app:app --reload # serve the env on :8000 +``` + +Or via Docker: + +```bash +docker build -f server/Dockerfile -t sql-drift-env:latest . +docker run -p 8000:8000 sql-drift-env:latest +curl -s http://localhost:8000/health +``` + +## Programmatic rollout + +A five-step in-process episode against scenario `01_correlated_subquery`: + +```python +from client import SqlDriftEnv +from server import SqlDriftEnvironment + +env = SqlDriftEnvironment() +obs = env.reset(seed=42, scenario_id="01_correlated_subquery") +print(obs.learned_hints) # surfaced from skill library + +obs = env.step(SqlDriftEnv.action_list_tables()) +obs = env.step(SqlDriftEnv.action_describe_table("users")) +obs = env.step(SqlDriftEnv.action_run_query("SELECT COUNT(*) FROM users")) +obs = env.step(SqlDriftEnv.action_submit_rewrite( + "SELECT u.*, COALESCE(c.n, 0) FROM users u " + "LEFT JOIN (SELECT user_id, COUNT(*) AS n FROM orders GROUP BY 1) c " + "ON c.user_id = u.user_id" +)) +print(obs.reward, obs.reward_components) +env.close() +``` + +End-to-end over an HTTP+WS OpenEnv server, see `SqlDriftEnv` in +`client.py` and the integration test suite in +`tests/integration/test_client_server.py` / +`tests/integration/test_state_no_leak.py`. + +## Evaluation + +```bash +uv run python -m training.eval \ + --checkpoint base \ + --scenarios 1-10 \ + --seeds-per-scenario 5 \ + --out outputs/evals/my_run +``` + +Emits `report.md`, `per_episode.csv`, and `summary.json`. A baseline random +agent sweep is checked in at `outputs/evals/baseline_random_v1/report.md` as a +floor: mean reward ≈ **-2.05**, pass rate **0 %**. + +## Training (GPU) + +`training/grpo_train.py` contains the GRPO entrypoint used by the hackathon +training notebook: it builds the curriculum dataset, loads +`unsloth/Qwen3-1.7B-unsloth-bnb-4bit`, attaches PEFT LoRA adapters, and lets TRL's +`GRPOTrainer` drive multi-turn OpenEnv rollouts through `SqlDriftToolEnv`. +Run `training/sql_drift_grpo_training.ipynb` on a GPU runtime with +`SQL_DRIFT_ENV_URL` pointed at the deployed SQLDrift Space; the notebook writes +durable evidence plots to `training/evidence/`. + +```bash +uv sync --extra train # installs trl, transformers, accelerate, unsloth +uv sync --extra evidence # matplotlib + pandas for utilities/plot_curves.py +``` + +## Repository layout + +``` +sql_drift_env/ +├── models.py # Pydantic v2 action/observation/state +├── client.py # SqlDriftEnv EnvClient (/ws) +├── engine/ +│ ├── runtime.py # private RuntimeEpisodeState +│ ├── drift.py # 4 DDL drift operators +│ ├── reward.py # SqlDriftRubric (6 child rubrics) +│ ├── verifier.py / profiler.py +├── scenarios/ # 10 hand-crafted scenario modules + registry +├── skill_library/ # pre-seeds + JSON store + Jaccard retrieval +├── actors/ # engineering_manager (changelog), dba_oracle +├── server/ # FastAPI app, Dockerfile, env wrapper class +├── training/ # config, prompt, random_agent, grpo_train, eval +├── utilities/ # env_loader, logger, plot_curves, run_training_job (HF Jobs), … +├── tests/ # 300+ unit + integration tests +├── outputs/evals/ # baseline eval artifacts +└── design/ # design docs (ignored by docker) +``` + +## Design references + +- `.cursor/plans/sqldrift_technical_spec_feb249d1.plan.md` — full 16-phase + technical spec (Rev-3). +- `design/CLAUDE.md` — agent coding guidelines. + +## License + +BSD-3-Clause (matches the OpenEnv project template). diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..201ccfa49dedf310a90685b4f35137fcdce90505 --- /dev/null +++ b/__init__.py @@ -0,0 +1,29 @@ +"""SQLDrift — OpenEnv gym for SQL repair & optimization under drift. + +The repo uses a flat top-level module layout (``models.py``, +``client.py``, ``engine/``, ``scenarios/``, ``skill_library/``, +``actors/``, ``server/``, ``training/``) because it is also run as a +FastAPI server that imports siblings absolutely (``from models import +…``). setuptools republishes ``.`` as the ``sql_drift_env`` package so +both import styles work at runtime, but eagerly re-exporting the flat +submodules from here would shadow the top-level ``import models`` / +``import client`` paths that every flat module relies on, and would +make the import order pytest-collection-sensitive. + +The public API for agent code is therefore the flat modules themselves, +imported directly: + + from client import SqlDriftEnv + from models import SqlDriftAction, SqlDriftObservation + from server import SqlDriftEnvironment + +This mirrors the flat layout both on disk and at import time; the +``sql_drift_env`` namespace exists only so the wheel has a canonical +name and so third parties can depend on a stable version string. +""" + +from __future__ import annotations + +__version__ = "0.1.0" + +__all__ = ["__version__"] diff --git a/_cli.py b/_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..dc8f45fc94570407ccfedc5dbf9d2d5d08e77bd3 --- /dev/null +++ b/_cli.py @@ -0,0 +1,47 @@ +"""Console-script entrypoint that patches sys.path for the flat-import layout. + +Problem +------- +The project uses a *flat layout* (``pyproject.toml`` maps ``.`` → the +``sql_drift_env`` package directory). That means every sibling module +(``models``, ``actors``, ``engine`` …) is imported as a plain top-level +name rather than via the ``sql_drift_env.`` prefix. + +When the wheel is installed, those siblings land at +``site-packages/sql_drift_env/models.py`` etc., *not* at the top-level +``site-packages/`` directory. A naïve console-script that calls +``sql_drift_env.server.app:main`` would fail at ``from models import …`` +before reaching any application logic. + +Fix +--- +Insert the installed package directory (``site-packages/sql_drift_env/``) +onto ``sys.path`` *before* importing anything from the server package. +This mirrors what Docker achieves via ``--app-dir /app/env`` / ``PYTHONPATH``, +but works for any installed-wheel invocation without requiring a wrapper +script or Docker. +""" + +from __future__ import annotations + +import os +import sys + + +def main() -> None: + # __file__ resolves to site-packages/sql_drift_env/_cli.py after + # installation, so its parent IS the directory that contains models.py, + # server/, actors/, etc. + _pkg_dir = os.path.dirname(os.path.abspath(__file__)) + if _pkg_dir not in sys.path: + sys.path.insert(0, _pkg_dir) + + # Import lazily so the sys.path fix takes effect before any flat import + # in server/app.py or its transitive dependencies is attempted. + from server.app import main as _server_main + + _server_main() + + +if __name__ == "__main__": + main() diff --git a/actors/__init__.py b/actors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..803372d00e63f712f6a4988351dd3e1376bd52e5 --- /dev/null +++ b/actors/__init__.py @@ -0,0 +1 @@ +"""Non-agent actors (Engineering Manager, DBA Oracle) — P4/P9.""" diff --git a/actors/dba_oracle.py b/actors/dba_oracle.py new file mode 100644 index 0000000000000000000000000000000000000000..23787604cb8d8215e6426e787234572207eeeeee --- /dev/null +++ b/actors/dba_oracle.py @@ -0,0 +1,183 @@ +"""DBA Oracle: per-scenario escalating expert guidance. + +Feature-flagged (off by default). When ``enable_dba_oracle=True`` is set +at ``reset(...)`` or ``SQL_DRIFT_ENABLE_DBA_ORACLE=1`` is exported in the +environment, the ``consult_dba`` tool becomes available; three hint +tiers per scenario are shipped here, covering all 10 scenarios (6 static ++ 4 drift). + +Hints escalate in specificity: tier 1 names the likely failure mode and +the diagnostic to confirm it, tier 2 gives the rewrite invariant, and +tier 3 is a near-spoiler SQL skeleton. The rubric penalizes each consult +(consultation rubric), so the agent only wins by consulting if the hint net-reduces +downstream steps. +""" + +from __future__ import annotations + +import os +from typing import Final + +# --------------------------------------------------------------------------- +# Per-scenario 3-tier hint tables +# --------------------------------------------------------------------------- + + +_HINTS: Final[dict[str, tuple[str, str, str]]] = { + "01_correlated_subquery": ( + "[DBA tier 1] The expensive shape is a projection-time correlated subquery: " + "a COUNT over `orders` for every `users` row. Confirm by spotting " + "`SELECT COUNT(*) ... WHERE o.user_id = u.id` in the SELECT list or by " + "checking EXPLAIN for repeated dependent work. Preserve one output row per user.", + "[DBA tier 2] Aggregate fulfilled orders once by `user_id`, then LEFT JOIN that " + "small result to `users`. Keep the join outer and wrap the count with " + "`COALESCE(..., 0)` so users with no fulfilled orders stay in the result.", + "[DBA tier 3] Use `SELECT u.id, u.tier, COALESCE(c.n, 0) AS fulfilled_orders " + "FROM users u LEFT JOIN (SELECT user_id, COUNT(*) AS n FROM orders WHERE " + "status = 'fulfilled' GROUP BY user_id) c ON c.user_id = u.id ORDER BY u.id`. " + "Validate the row count equals the number of users.", + ), + "02_select_star_join": ( + "[DBA tier 1] The waste is over-projection: the inner three-way join uses " + "`SELECT *`, including wide product text and order metadata, while the outer " + "query keeps only `order_id`, product `name`, and `qty`.", + "[DBA tier 2] Inline the join and project exactly `oi.order_id`, `p.name`, and " + "`oi.qty`. Keep the products and orders joins plus the filters " + "`p.category = 'books'` and `oi.qty >= 2`; the wrapper exists only to hide " + "the star projection.", + "[DBA tier 3] Rewrite as `SELECT oi.order_id, p.name, oi.qty FROM order_items oi " + "JOIN products p ON p.id = oi.product_id JOIN orders o ON o.id = oi.order_id " + "WHERE p.category = 'books' AND oi.qty >= 2 ORDER BY oi.order_id, p.name`.", + ), + "03_cartesian_join": ( + "[DBA tier 1] This is an accidental cartesian join. `FROM events e, tenants t` " + "combined with `t.id = e.tenant_id + 0` prevents the optimizer from seeing a " + "clean tenant-key join early.", + "[DBA tier 2] Turn the comma join into an explicit equijoin on the tenant key. " + "Move only `t.id = e.tenant_id` into `ON`; keep the severity filter in `WHERE` " + "and preserve grouping by tenant tier.", + "[DBA tier 3] Use `SELECT t.tier, COUNT(*) AS n FROM events e JOIN tenants t " + "ON t.id = e.tenant_id WHERE e.severity IN ('error', 'critical') GROUP BY " + "t.tier ORDER BY t.tier`. Avoid arithmetic on the join key.", + ), + "04_distinct_groupby": ( + "[DBA tier 1] The duplicate-removal work is redundant. `GROUP BY session_id, path` " + "already emits one row per `(session_id, path)` pair, so a leading `DISTINCT` " + "adds a second deduplication pass over grouped rows.", + "[DBA tier 2] Do not introduce a CTE or change the aggregation grain. Remove only " + "`DISTINCT`; keep `COUNT(*) AS hits`, the same GROUP BY keys, and the same " + "ordering so row identity and sort order stay stable.", + "[DBA tier 3] The target shape is `SELECT session_id, path, COUNT(*) AS hits " + "FROM pageviews GROUP BY session_id, path ORDER BY session_id, path`. Validate " + "against the baseline result before comparing runtime.", + ), + "05_nested_subquery": ( + "[DBA tier 1] The nested `IN` clauses express a semi-join: authors who wrote " + "comments on published articles. The important identity is `comments.author_id`, " + "not `articles.author_id`.", + "[DBA tier 2] Flatten to `authors -> comments -> articles`, filter " + "`articles.status = 'published'`, and select distinct author display names. " + "`DISTINCT` is required here because one author can have many qualifying comments.", + "[DBA tier 3] Use `SELECT DISTINCT a.display_name FROM authors a JOIN comments c " + "ON c.author_id = a.id JOIN articles ar ON ar.id = c.article_id WHERE " + "ar.status = 'published' ORDER BY a.display_name`.", + ), + "06_having_as_where": ( + "[DBA tier 1] `status = 'fulfilled'` is a row-level predicate sitting in HAVING, " + "so the engine groups every status first and discards most groups afterward. " + "Only `SUM(amount_cents) >= 100000` truly belongs after aggregation.", + "[DBA tier 2] Move the status filter into `WHERE` before the GROUP BY. Keep " + "`status` in the projection and grouping to preserve the result shape, then " + "leave the aggregate threshold in HAVING.", + "[DBA tier 3] Use `SELECT tenant_id, status, SUM(amount_cents) AS total_cents " + "FROM orders WHERE status = 'fulfilled' GROUP BY tenant_id, status HAVING " + "SUM(amount_cents) >= 100000 ORDER BY tenant_id`.", + ), + "07_drift_column_rename": ( + "[DBA tier 1] If the old aggregation now fails with an unknown `user_id`, this is " + "schema drift rather than a performance issue. Read the changelog or describe " + "`orders`; `users.id` is unchanged, but the order-owner column moved.", + "[DBA tier 2] Replace every reference to `orders.user_id` with `orders.account_id` " + "in SELECT, GROUP BY, JOIN, and ORDER BY positions. Do not change the aggregate " + "logic; the rename preserves row semantics.", + "[DBA tier 3] Submit `SELECT account_id, COUNT(*) AS n_orders, ROUND(SUM(amount), 2) " + "AS total FROM orders GROUP BY account_id ORDER BY account_id`. Validate that " + "counts and totals match the pre-drift business result.", + ), + "08_drift_date_format": ( + "[DBA tier 1] The `events.ts` identifier still exists, but its type changed from " + "ISO text to BIGINT epoch milliseconds. A string date predicate can parse or " + "compare incorrectly; confirm with `describe_table('events')` and samples.", + "[DBA tier 2] Keep the same half-open UTC day window, but express both bounds as " + "epoch-ms integers. For 2026-04-21T00:00:00Z through the next midnight, use " + "`1776729600000 <= ts < 1776816000000`.", + "[DBA tier 3] Use `SELECT kind, COUNT(*) AS n FROM events WHERE ts >= " + "1776729600000 AND ts < 1776816000000 GROUP BY kind ORDER BY kind`. Do not quote " + "the bounds; they must be numeric comparisons against the BIGINT column.", + ), + "09_drift_enum_rule": ( + "[DBA tier 1] A formerly valid equality on `status = 'active'` now silently loses " + "rows because the business state was split into multiple stored labels. Sample " + "`tenants.status` before assuming the old lowercase value still exists.", + '[DBA tier 2] Preserve the business meaning "active tenants" by filtering on the ' + "union of replacement labels. Keep the same grouping by tier and ordering; only " + "the status predicate changes.", + "[DBA tier 3] Use `SELECT tier, COUNT(*) AS n FROM tenants WHERE status IN " + "('ACTIVE', 'ACTIVE_V2') GROUP BY tier ORDER BY tier`. Avoid `LOWER(status) = " + "'active'`; it misses `ACTIVE_V2`.", + ), + "10_drift_field_deprecation": ( + "[DBA tier 1] The inline `posts.author_name` column was normalized away. Describe " + "`posts` and list tables: you should see `posts.users_id` plus a new `users` " + "lookup carrying the human-readable name.", + "[DBA tier 2] Join `posts` to `users` through the new FK, group by `u.full_name`, " + "and alias it back to `author_name` so the result keeps the old report shape. " + "The post count still comes from `posts`.", + "[DBA tier 3] Use `SELECT u.full_name AS author_name, COUNT(*) AS n_posts FROM " + "posts p JOIN users u ON u.id = p.users_id GROUP BY u.full_name ORDER BY " + "u.full_name`.", + ), +} + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def is_enabled( + reset_flag: bool | None = None, *, env_var: str = "SQL_DRIFT_ENABLE_DBA_ORACLE" +) -> bool: + """Resolve the feature flag from (reset kwarg, env var, default-off).""" + if reset_flag is not None: + return bool(reset_flag) + raw = os.environ.get(env_var, "").strip().lower() + return raw in ("1", "true", "yes", "on") + + +def get_hint(scenario_id: str, tier: int) -> str: + """Return the hint for ``(scenario_id, tier)``; clamps tier to [1, 3]. + + Raises :class:`KeyError` on unknown scenario so tests can detect when + a new scenario was added without a hint table entry. + """ + if scenario_id not in _HINTS: + raise KeyError(f"no DBA hints for scenario_id={scenario_id!r}; known: {sorted(_HINTS)}") + tier = max(1, min(3, int(tier))) + return _HINTS[scenario_id][tier - 1] + + +def has_hints(scenario_id: str) -> bool: + return scenario_id in _HINTS + + +def known_scenarios() -> frozenset[str]: + return frozenset(_HINTS) + + +__all__ = [ + "get_hint", + "has_hints", + "is_enabled", + "known_scenarios", +] diff --git a/actors/engineering_manager.py b/actors/engineering_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..d9fbcb3259a10afe64c33d7f6cabf9a24d3ac622 --- /dev/null +++ b/actors/engineering_manager.py @@ -0,0 +1,81 @@ +"""Engineering-manager narrator — authors deterministic drift runbooks. + +Consumed by the ``read_changelog`` tool. Output is deterministic per +``DriftConfig`` so tests can assert it character-for-character and the +agent can learn to parse drift kinds from the text. + +Entries are concise migration notes: what changed, why the old query +breaks, how to adapt, and what to validate before submission. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from scenarios.base import DriftConfig + + +def _sentence(prefix: str, body: str) -> str: + return f"[changelog] {prefix} {body}".strip() + + +def author_changelog(drift_config: DriftConfig) -> str: + kind = drift_config.kind + p = drift_config.payload + if kind == "column_rename": + return _sentence( + "Schema change -", + ( + f"`{p['table']}.{p['old']}` was renamed to `{p['table']}.{p['new']}` " + "as part of a naming normalization. Impact: queries that SELECT, JOIN, " + "GROUP BY, ORDER BY, or filter on the old identifier now fail at bind " + "time. Migration: use the new column everywhere the old one appeared; " + "related table primary keys are unchanged. Validate that counts and " + "aggregates are unchanged after the rename." + ), + ) + if kind == "date_format": + return _sentence( + "Format change -", + ( + f"`{p['table']}.{p['col']}` no longer stores ISO-8601 strings; it now " + "stores BIGINT milliseconds since the Unix epoch. Impact: quoted " + "timestamp literals and text comparisons no longer express the right " + "predicate. Migration: keep the same half-open time window, but compare " + "against numeric epoch-ms bounds. Validate with describe_table and a " + "small sample before submitting." + ), + ) + if kind == "enum_rule": + new_values = ", ".join(f"`{v}`" for v in p["new_values"]) + return _sentence( + "Business-rule change -", + ( + f"`{p['table']}.{p['col']}` value `{p['old_value']}` has been split " + f"into {new_values}. Impact: equality predicates on the old value " + "silently undercount after the deploy. Migration: replace the single " + "value predicate with an IN predicate over every replacement label. " + "Validate by sampling the enum distribution and preserving the existing " + "GROUP BY and projection." + ), + ) + if kind == "field_deprecation": + orig_t, orig_c = p["orig"] + lt, lid, lname = p["lookup"] + fk_col = f"{lt}_{lid}" + return _sentence( + "Deprecation -", + ( + f"`{orig_t}.{orig_c}` was deprecated and replaced by `{orig_t}.{fk_col}` " + f"pointing at `{lt}.{lid}`; the display value now lives on `{lt}.{lname}`. " + "Impact: projections or groups over the old inline string fail after " + "drift. Migration: join through the lookup table, project the display " + "column, and keep the old output alias if callers expect it. Validate " + "row counts after the join." + ), + ) + raise ValueError(f"unknown drift kind={kind!r}") + + +__all__ = ["author_changelog"] diff --git a/client.py b/client.py new file mode 100644 index 0000000000000000000000000000000000000000..4fedf411ddf54c00b5f919f1c892ba3d835cf344 --- /dev/null +++ b/client.py @@ -0,0 +1,121 @@ +"""SQLDrift ``EnvClient`` — tool-aware payload constructors + response parser. + +Inherits :class:`openenv.core.env_client.EnvClient` so TRL rollouts, +notebook exploration, and integration tests all use the same WS-backed +session semantics. Stateful episodes MUST go through the ``/ws`` channel +(HTTP ``/step`` is stateless: one fresh env per request). + +Convenience constructors (:meth:`SqlDriftEnv.action_list_tables`, etc.) +hide the discriminated-union boilerplate so agent code reads naturally:: + + env = SqlDriftEnv(base_url="http://localhost:8000").sync() + with env: + r = env.reset(seed=42, scenario_id="03_cartesian_join") + r = env.step(SqlDriftEnv.action_run_query("SELECT COUNT(*) FROM events")) + ... +""" + +from __future__ import annotations + +from typing import Any + +from openenv.core.client_types import StepResult +from openenv.core.env_client import EnvClient + +from models import ( + ConsultDBAPayload, + DescribeTablePayload, + ExplainQueryPayload, + ListTablesPayload, + ReadChangelogPayload, + RunQueryPayload, + SampleRowsPayload, + SqlDriftAction, + SqlDriftObservation, + SqlDriftState, + SubmitRewritePayload, + ToolName, +) + + +class SqlDriftEnv(EnvClient[SqlDriftAction, SqlDriftObservation, SqlDriftState]): + """Tool-aware client for the SQLDrift OpenEnv environment.""" + + # ------------------------------------------------------------------ + # EnvClient ABC implementations + # ------------------------------------------------------------------ + + def _step_payload(self, action: SqlDriftAction) -> dict[str, Any]: + return action.model_dump(mode="json") + + def _parse_result(self, payload: dict[str, Any]) -> StepResult[SqlDriftObservation]: + obs_data = payload.get("observation", {}) + observation = SqlDriftObservation.model_validate(obs_data) + # Base transport strips reward + done off the observation dict — we + # re-populate them so the agent can read straight off `.observation`. + reward = payload.get("reward") + done = bool(payload.get("done", False)) + observation.reward = reward + observation.done = done + return StepResult(observation=observation, reward=reward, done=done) + + def _parse_state(self, payload: dict[str, Any]) -> SqlDriftState: + return SqlDriftState.model_validate(payload) + + # ------------------------------------------------------------------ + # Action factories — one per tool, accepting only the args that tool + # cares about; payload.kind is filled in automatically. + # ------------------------------------------------------------------ + + @staticmethod + def action_list_tables() -> SqlDriftAction: + return SqlDriftAction(tool=ToolName.LIST_TABLES, payload=ListTablesPayload()) + + @staticmethod + def action_describe_table(table: str) -> SqlDriftAction: + return SqlDriftAction( + tool=ToolName.DESCRIBE_TABLE, + payload=DescribeTablePayload(table=table), + ) + + @staticmethod + def action_sample_rows(table: str, limit: int = 5) -> SqlDriftAction: + return SqlDriftAction( + tool=ToolName.SAMPLE_ROWS, + payload=SampleRowsPayload(table=table, limit=limit), + ) + + @staticmethod + def action_run_query(sql: str) -> SqlDriftAction: + return SqlDriftAction( + tool=ToolName.RUN_QUERY, + payload=RunQueryPayload(sql=sql), + ) + + @staticmethod + def action_explain_query(sql: str) -> SqlDriftAction: + return SqlDriftAction( + tool=ToolName.EXPLAIN_QUERY, + payload=ExplainQueryPayload(sql=sql), + ) + + @staticmethod + def action_read_changelog() -> SqlDriftAction: + return SqlDriftAction(tool=ToolName.READ_CHANGELOG, payload=ReadChangelogPayload()) + + @staticmethod + def action_submit_rewrite(sql: str) -> SqlDriftAction: + return SqlDriftAction( + tool=ToolName.SUBMIT_REWRITE, + payload=SubmitRewritePayload(sql=sql), + ) + + @staticmethod + def action_consult_dba(question: str) -> SqlDriftAction: + return SqlDriftAction( + tool=ToolName.CONSULT_DBA, + payload=ConsultDBAPayload(question=question), + ) + + +__all__ = ["SqlDriftEnv"] diff --git a/engine/__init__.py b/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..efb1657eac544524bbe0e9217fcb277b59f5f2ba --- /dev/null +++ b/engine/__init__.py @@ -0,0 +1 @@ +"""Deterministic verifier / profiler / drift / reward engine.""" diff --git a/engine/drift.py b/engine/drift.py new file mode 100644 index 0000000000000000000000000000000000000000..164617c57d27b9b3d6eaf03da19ed870684ac675 --- /dev/null +++ b/engine/drift.py @@ -0,0 +1,242 @@ +"""Drift engine: four atomic, idempotent DDL operations. + +Each apply_* function mutates ``conn`` in place inside a DuckDB +``BEGIN; ... COMMIT`` pair and returns a machine-readable changelog +string. Humans consume the string via the :class:`read_changelog` +tool; the rubric consults a separate drift-acknowledgement flag on the +runtime state, not the string itself. + +Idempotency is enforced via a post-condition schema probe: once the drift +has been applied (the target column / enum value is in the expected +post-state), a second call short-circuits with the same changelog string. +This matters because the environment's drift-trigger check runs every +step and needs to be safe to retry. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + import duckdb + + +# DuckDB auto-commits DDL and forbids mixing multi-statement transactions +# with schema alterations across commit boundaries. Each drift operation +# therefore executes its statements sequentially on the default +# auto-commit connection; individual DML statements (UPDATEs) are +# internally atomic at the statement level, which is sufficient for the +# fixture mutation the env needs. If a drift operation raises mid-way we +# tear down and re-seed the DuckDB via ScenarioSpec.materialize — there's +# no long-lived on-disk state to roll back. + + +def _table_columns(conn: duckdb.DuckDBPyConnection, table: str) -> list[str]: + rows = conn.execute(f"PRAGMA table_info('{table}')").fetchall() + # PRAGMA table_info returns (cid, name, type, notnull, dflt_value, pk) + return [r[1] for r in rows] + + +def _table_exists(conn: duckdb.DuckDBPyConnection, table: str) -> bool: + rows = conn.execute( + "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = ?", [table] + ).fetchone() + return bool(rows and rows[0]) + + +# ============================================================================= +# Column rename +# ============================================================================= + + +def apply_column_rename(conn: duckdb.DuckDBPyConnection, payload: dict[str, Any]) -> str: + """``{"table": str, "old": str, "new": str}``.""" + table = payload["table"] + old = payload["old"] + new = payload["new"] + + cols = _table_columns(conn, table) + if new in cols and old not in cols: + return f"rename_already_applied:{table}.{old}->{new}" + if old not in cols: + raise ValueError(f"column_rename: {table}.{old} missing (cols={cols})") + + conn.execute(f'ALTER TABLE "{table}" RENAME COLUMN "{old}" TO "{new}"') + return f"rename:{table}.{old}->{new}" + + +# ============================================================================= +# Date format change (iso_string → epoch_ms) +# ============================================================================= + + +def apply_date_format_change(conn: duckdb.DuckDBPyConnection, payload: dict[str, Any]) -> str: + """``{"table": str, "col": str, "from": "iso_string", "to": "epoch_ms"}``. + + Only the one direction is supported for now; the payload still carries + from/to for forward-compatibility and audit. + """ + table = payload["table"] + col = payload["col"] + from_fmt = payload.get("from", "iso_string") + to_fmt = payload.get("to", "epoch_ms") + if (from_fmt, to_fmt) != ("iso_string", "epoch_ms"): + raise NotImplementedError( + f"date_format_change only supports iso_string→epoch_ms, got {from_fmt}→{to_fmt}" + ) + + cols = _table_columns(conn, table) + # Idempotent: once column is BIGINT, consider it applied. + type_row = conn.execute( + "SELECT data_type FROM information_schema.columns WHERE table_name = ? AND column_name = ?", + [table, col], + ).fetchone() + if type_row is None: + raise ValueError(f"date_format_change: {table}.{col} missing (cols={cols})") + if "BIGINT" in type_row[0].upper() or "INT" in type_row[0].upper(): + return f"date_format_already_applied:{table}.{col}" + + tmp = f"{col}_epoch_ms" + conn.execute(f'ALTER TABLE "{table}" ADD COLUMN "{tmp}" BIGINT') + conn.execute( + f'UPDATE "{table}" SET "{tmp}" = ' + f'CAST(EXTRACT(EPOCH FROM CAST("{col}" AS TIMESTAMP)) * 1000 AS BIGINT)' + ) + conn.execute(f'ALTER TABLE "{table}" DROP COLUMN "{col}"') + conn.execute(f'ALTER TABLE "{table}" RENAME COLUMN "{tmp}" TO "{col}"') + + return f"date_format:{table}.{col}:iso_string->epoch_ms" + + +# ============================================================================= +# Enum rule change (split `old_value` into N new values) +# ============================================================================= + + +def apply_enum_rule_change(conn: duckdb.DuckDBPyConnection, payload: dict[str, Any]) -> str: + """``{"table": str, "col": str, "old_value": str, "new_values": list[str]}``. + + Rows holding ``old_value`` are re-distributed deterministically into + ``new_values`` (round-robin by rowid) so the split is reproducible. + """ + table = payload["table"] + col = payload["col"] + old_value = payload["old_value"] + new_values: list[str] = list(payload["new_values"]) + if not new_values: + raise ValueError("enum_rule_change: new_values must be non-empty") + + count_row = conn.execute( + f'SELECT COUNT(*) FROM "{table}" WHERE "{col}" = ?', [old_value] + ).fetchone() + count_old = count_row[0] if count_row is not None else 0 + # Idempotent: if old_value has already been drained AND any of the + # new_values is present, treat as applied. + if count_old == 0: + has_new_row = conn.execute( + f'SELECT COUNT(*) FROM "{table}" WHERE "{col}" IN ({",".join("?" * len(new_values))})', + new_values, + ).fetchone() + has_new = has_new_row[0] if has_new_row is not None else 0 + if has_new > 0: + return f"enum_rule_already_applied:{table}.{col}:{old_value}->{new_values}" + + # Deterministic split by rowid mod N. + case_branches = " ".join( + f"WHEN mod(rid, {len(new_values)}) = {i} THEN '{v}'" for i, v in enumerate(new_values) + ) + conn.execute( + f"CREATE TEMP TABLE _enum_remap AS " + f"SELECT rowid AS rid, " + f"CASE {case_branches} END AS new_val " + f'FROM "{table}" WHERE "{col}" = ?', + [old_value], + ) + conn.execute( + f'UPDATE "{table}" SET "{col}" = _enum_remap.new_val ' + f'FROM _enum_remap WHERE _enum_remap.rid = "{table}".rowid' + ) + conn.execute("DROP TABLE _enum_remap") + + return f"enum_rule:{table}.{col}:{old_value}->{'+'.join(new_values)}" + + +# ============================================================================= +# Field deprecation (replace inline string col with FK lookup) +# ============================================================================= + + +def apply_field_deprecation(conn: duckdb.DuckDBPyConnection, payload: dict[str, Any]) -> str: + """``{"orig": (table, col), "lookup": (table, id_col, name_col)}``. + + - Creates the lookup table (if missing) and seeds it with distinct values + observed on ``orig.col``. + - Adds ``orig.`` with a FK-style backfill. + - Drops ``orig.col``. + """ + orig_table, orig_col = payload["orig"] + lookup_table, lookup_id_col, lookup_name_col = payload["lookup"] + new_fk_col = f"{lookup_table}_{lookup_id_col}" # e.g. "users_id" + + orig_cols = _table_columns(conn, orig_table) + if orig_col not in orig_cols and new_fk_col in orig_cols: + return f"field_deprecation_already_applied:{orig_table}.{orig_col}" + if orig_col not in orig_cols: + raise ValueError(f"field_deprecation: {orig_table}.{orig_col} missing (cols={orig_cols})") + + if not _table_exists(conn, lookup_table): + conn.execute( + f'CREATE TABLE "{lookup_table}" (' + f' "{lookup_id_col}" BIGINT PRIMARY KEY,' + f' "{lookup_name_col}" VARCHAR' + ");" + ) + conn.execute( + f'INSERT INTO "{lookup_table}" ("{lookup_id_col}", "{lookup_name_col}") ' + f"SELECT ROW_NUMBER() OVER (ORDER BY v) + " + f'COALESCE((SELECT MAX("{lookup_id_col}") FROM "{lookup_table}"), 0), v ' + f'FROM (SELECT DISTINCT "{orig_col}" AS v FROM "{orig_table}") ' + f"WHERE v IS NOT NULL " + f' AND v NOT IN (SELECT "{lookup_name_col}" FROM "{lookup_table}");' + ) + conn.execute(f'ALTER TABLE "{orig_table}" ADD COLUMN "{new_fk_col}" BIGINT') + conn.execute( + f'UPDATE "{orig_table}" SET "{new_fk_col}" = lookup."{lookup_id_col}" ' + f'FROM "{lookup_table}" lookup ' + f'WHERE lookup."{lookup_name_col}" = "{orig_table}"."{orig_col}"' + ) + conn.execute(f'ALTER TABLE "{orig_table}" DROP COLUMN "{orig_col}"') + + return ( + f"field_deprecation:{orig_table}.{orig_col}->" + f"{orig_table}.{new_fk_col}→{lookup_table}.{lookup_name_col}" + ) + + +# ============================================================================= +# Dispatcher +# ============================================================================= + + +DRIFT_HANDLERS = { + "column_rename": apply_column_rename, + "date_format": apply_date_format_change, + "enum_rule": apply_enum_rule_change, + "field_deprecation": apply_field_deprecation, +} + + +def apply_drift(conn: duckdb.DuckDBPyConnection, kind: str, payload: dict[str, Any]) -> str: + if kind not in DRIFT_HANDLERS: + raise ValueError(f"unknown drift kind={kind!r}; known: {sorted(DRIFT_HANDLERS)}") + return DRIFT_HANDLERS[kind](conn, payload) + + +__all__ = [ + "DRIFT_HANDLERS", + "apply_column_rename", + "apply_date_format_change", + "apply_drift", + "apply_enum_rule_change", + "apply_field_deprecation", +] diff --git a/engine/profiler.py b/engine/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..14c1eb27ac80513ae839504ccf7737c7d5ad54aa --- /dev/null +++ b/engine/profiler.py @@ -0,0 +1,298 @@ +"""Query profiling utilities. + +A watchdog-wrapped DuckDB execute plus a median-of-3 warm timer. + +* :func:`execute_once_timed` runs a statement exactly once, enforcing a + hard ``timeout_s`` wall-clock budget. It is the single entry point used + by the env for agent-provided SQL so the documented query timeout + cannot be bypassed. An optional ``max_rows`` caps result-set + materialization — the fetch is aborted as soon as more than + ``max_rows`` rows are observed, so a pathological ``SELECT *`` cannot + drive the server OOM before the caller's size check runs. +* :func:`execute_hash_timed` executes a statement once and hashes its full + result incrementally via ``fetchmany`` so correctness checks do not have + to materialize the full row set in Python memory. +* :func:`median_of_3_warm_ms` performs one untimed warm-up then three + timed runs and returns the median milliseconds. Used by scenario + materialization to publish a stable baseline runtime. + +Both helpers raise :class:`TimeoutError` when a single run exceeds the +budget; ``duckdb.Error`` propagates unchanged to the caller. +""" + +from __future__ import annotations + +import contextlib +import os +import threading +import time +from collections.abc import Callable, Iterator +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast + +from engine.verifier import canonical_row_hash +from utilities.logger import get_module_logger + +if TYPE_CHECKING: + import duckdb + + +DEFAULT_TIMEOUT_S: float = 2.0 +INTERRUPT_GRACE_S: float = 0.25 +# Maximum number of watchdog escalations (leaked threads) tolerated before +# logging at CRITICAL. Override via SQL_DRIFT_MAX_LEAKED_WATCHDOG_THREADS. +MAX_LEAKED_WATCHDOG_THREADS: int = int(os.environ.get("SQL_DRIFT_MAX_LEAKED_WATCHDOG_THREADS", "3")) + +_LOG = get_module_logger(__name__) +_FETCH_CHUNK_ROWS = 1024 + +# Module-level counter — incremented each time a watchdog thread survives +# interrupt (i.e. a genuine escalation, not a normal timeout). Thread-safe +# via _watchdog_leak_lock. Callers can read this via get_watchdog_leak_count(). +_watchdog_leak_lock: threading.Lock = threading.Lock() +_watchdog_leaked_count: int = 0 + + +def get_watchdog_leak_count() -> int: + """Return the cumulative number of watchdog threads that survived interrupt. + + A non-zero value means at least one DuckDB worker thread was not stopped + cleanly and is still alive in the background. Production monitoring should + alert when this exceeds :data:`MAX_LEAKED_WATCHDOG_THREADS`. + """ + return _watchdog_leaked_count + + +class QueryWatchdogEscalationError(RuntimeError): + """DuckDB worker survived interrupt; the connection is no longer safe.""" + + +@dataclass(frozen=True) +class TimedResult: + """Output of :func:`execute_once_timed`. + + ``columns`` preserves DuckDB's cursor ``description`` order so callers + can emit a :class:`models.RunQueryResult` without re-executing the + query just to recover column names. + + ``truncated`` is ``True`` when the caller supplied a ``max_rows`` cap + and the query produced strictly more rows than that cap; in that + case ``rows`` contains exactly ``max_rows + 1`` entries (the + one-over read that proves overflow). Callers that care about size + limits must branch on ``truncated`` rather than re-checking + ``len(rows)`` against their cap. + """ + + columns: list[str] + rows: list[tuple[Any, ...]] + elapsed_ms: float + truncated: bool = False + + +def _fetch_capped( + cursor: duckdb.DuckDBPyConnection, + max_rows: int, +) -> tuple[list[tuple[Any, ...]], bool]: + """Drain at most ``max_rows + 1`` rows from ``cursor`` via fetchmany. + + Returns ``(rows, truncated)``. When ``truncated`` is ``True`` the + cursor still has unread rows — we stopped on the first row past the + cap so the caller can signal overflow without materialising the + rest of a potentially enormous result set. + """ + # chunk=1024 trades a few extra Python calls for not over-fetching + # by orders of magnitude when results are modest. The +1 in the + # final budget is what makes overflow detectable. + rows: list[tuple[Any, ...]] = [] + budget = max_rows + 1 + while budget > 0: + batch = cursor.fetchmany(min(_FETCH_CHUNK_ROWS, budget)) + if not batch: + return rows, False + rows.extend(batch) + budget -= len(batch) + return rows, len(rows) > max_rows + + +def _iter_cursor_rows( + cursor: duckdb.DuckDBPyConnection, +) -> Iterator[tuple[Any, ...]]: + while True: + batch = cursor.fetchmany(_FETCH_CHUNK_ROWS) + if not batch: + return + yield from batch + + +def _run_worker_with_watchdog[T]( + conn: duckdb.DuckDBPyConnection, + sql: str, + timeout_s: float, + worker: Callable[[], T], +) -> T: + result_holder: dict[str, object] = {} + + def runner() -> None: + try: + result_holder["result"] = worker() + except BaseException as exc: # Must forward all failures from the worker thread. + result_holder["error"] = exc + + thread = threading.Thread(target=runner, daemon=True) + thread.start() + thread.join(timeout_s) + if thread.is_alive(): + # DuckDB's interrupt API is connection-scoped and thread-safe; + # we ask the query to unwind and then wait *unconditionally* + # for the worker to exit before surfacing the timeout to the + # caller. If we returned while the thread were still alive, it + # would retain access to ``conn`` and its result could race + # future queries on the same connection — a previously + # observed source of flaky post-timeout behaviour. In practice + # DuckDB's interrupt releases the worker within a handful of + # milliseconds; if the engine ever fails to honour interrupt + # the process will hang here, which is the correct failure + # mode for a connection whose state is no longer safe to + # reuse. + with contextlib.suppress(Exception): + conn.interrupt() + thread.join(INTERRUPT_GRACE_S) + if thread.is_alive(): + global _watchdog_leaked_count + with _watchdog_leak_lock: + _watchdog_leaked_count += 1 + leak_count = _watchdog_leaked_count + log_fn = _LOG.critical if leak_count > MAX_LEAKED_WATCHDOG_THREADS else _LOG.error + log_fn( + "query watchdog failed to stop worker after %.3fs timeout + %.3fs grace" + " (cumulative leaked threads: %d)", + timeout_s, + INTERRUPT_GRACE_S, + leak_count, + ) + raise QueryWatchdogEscalationError( + f"query exceeded {timeout_s}s and worker did not stop after interrupt: {sql[:120]!r}" + ) + raise TimeoutError(f"query exceeded {timeout_s}s: {sql[:120]!r}") + if "error" in result_holder: + error = result_holder["error"] + assert isinstance(error, BaseException) + raise error + return cast(T, result_holder["result"]) + + +def _run_with_watchdog( + conn: duckdb.DuckDBPyConnection, + sql: str, + timeout_s: float, + max_rows: int | None, +) -> TimedResult: + def worker() -> TimedResult: + start = time.perf_counter_ns() + cursor = conn.execute(sql) + columns = [d[0] for d in cursor.description] if cursor.description else [] + if max_rows is None: + rows = cursor.fetchall() + truncated = False + else: + rows, truncated = _fetch_capped(cursor, max_rows) + elapsed_ns = time.perf_counter_ns() - start + return TimedResult( + columns=columns, + rows=rows, + elapsed_ms=elapsed_ns / 1_000_000.0, + truncated=truncated, + ) + + result = _run_worker_with_watchdog(conn, sql, timeout_s, worker) + assert isinstance(result, TimedResult) + return result + + +def execute_once_timed( + conn: duckdb.DuckDBPyConnection, + sql: str, + *, + timeout_s: float = DEFAULT_TIMEOUT_S, + max_rows: int | None = None, +) -> tuple[list[tuple[Any, ...]], float]: + """Single timed execution — returns ``(rows, elapsed_ms)``. + + Thin wrapper for callers that don't need column metadata or the + truncation flag. + """ + res = _run_with_watchdog(conn, sql, timeout_s, max_rows) + return res.rows, res.elapsed_ms + + +def execute_once_with_columns( + conn: duckdb.DuckDBPyConnection, + sql: str, + *, + timeout_s: float = DEFAULT_TIMEOUT_S, + max_rows: int | None = None, +) -> TimedResult: + """Single timed execution — returns columns + rows + elapsed_ms. + + When ``max_rows`` is supplied, the fetch aborts at the first row + past the cap and ``TimedResult.truncated`` is set. The elapsed + milliseconds in that case reflect the partial scan, not the query's + would-be completion time — a truncated read is a *hard error* in + agent-facing code paths, not a performance measurement. + """ + return _run_with_watchdog(conn, sql, timeout_s, max_rows) + + +def execute_hash_timed( + conn: duckdb.DuckDBPyConnection, + sql: str, + *, + timeout_s: float = DEFAULT_TIMEOUT_S, +) -> tuple[str, float]: + """Single timed execution — returns ``(result_hash, elapsed_ms)``. + + Unlike :func:`execute_once_timed`, this drains the cursor via + ``fetchmany`` and hashes rows incrementally, so callers can compare a + large final result to ground truth without materializing the full row + set in Python memory. + """ + + def worker() -> tuple[str, float]: + start = time.perf_counter_ns() + cursor = conn.execute(sql) + result_hash = canonical_row_hash(_iter_cursor_rows(cursor)) + elapsed_ns = time.perf_counter_ns() - start + return result_hash, elapsed_ns / 1_000_000.0 + + result = _run_worker_with_watchdog(conn, sql, timeout_s, worker) + result_hash, elapsed_ms = result + assert isinstance(result_hash, str) + assert isinstance(elapsed_ms, float) + return result_hash, elapsed_ms + + +def median_of_3_warm_ms( + conn: duckdb.DuckDBPyConnection, + sql: str, + *, + timeout_s: float = DEFAULT_TIMEOUT_S, +) -> float: + """Warm cache, then median-of-3 timed runs. Returns milliseconds.""" + _run_with_watchdog(conn, sql, timeout_s, None) + timings = [_run_with_watchdog(conn, sql, timeout_s, None).elapsed_ms for _ in range(3)] + timings.sort() + return timings[1] + + +__all__ = [ + "DEFAULT_TIMEOUT_S", + "INTERRUPT_GRACE_S", + "MAX_LEAKED_WATCHDOG_THREADS", + "QueryWatchdogEscalationError", + "TimedResult", + "execute_hash_timed", + "execute_once_timed", + "execute_once_with_columns", + "get_watchdog_leak_count", + "median_of_3_warm_ms", +] diff --git a/engine/reward.py b/engine/reward.py new file mode 100644 index 0000000000000000000000000000000000000000..23825a0d45ca01de12e26281c74b6d4c502beaed --- /dev/null +++ b/engine/reward.py @@ -0,0 +1,437 @@ +"""SQLDrift composed rubric (correctness, drift, speedup, gates, DBA tax). + +Six child rubrics, one per reward component (:data:`REWARD_COMPONENT_KEYS`): + + r_correct correctness vs ground-truth hash, gated on ≥ 1.2× speedup + r_drift bonus/penalty for (not) adapting to post-drift identifiers + r_speedup tanh-shaped speedup bonus, gated on r_correct > 0 + r_step_tax base step tax plus bounded productive-action rebates + r_gatekeepers escalating tool-error / repeat-failing / no-op penalties + r_consult_dba DBA-oracle consult penalties (feature-flagged; 0 when off) + +All child rubrics share a single ``ctx_provider`` that returns the private +:class:`engine.runtime.RuntimeEpisodeState`; this keeps the rubric +stateless relative to the environment and makes each component +individually unit-testable with a synthesized triple +``(RuntimeEpisodeState, SqlDriftAction, SqlDriftObservation)``. +""" + +from __future__ import annotations + +import math +import re +from collections.abc import Callable +from typing import TYPE_CHECKING + +from openenv.core.rubrics import Rubric + +from models import ( + SqlDriftAction, + SqlDriftObservation, + SubmitRewriteResult, + ToolError, + ToolName, +) + +if TYPE_CHECKING: + from engine.runtime import RuntimeEpisodeState + + +# Tunable thresholds pulled out to module level so tests and future +# curriculum code share a single source of truth. +SPEEDUP_MIN: float = 1.2 +SPEEDUP_CAP_FOR_INFTY: float = 64.0 +STEP_TAX: float = -0.03 +STEP_REBATE_LIST_TABLES: float = 0.04 +STEP_REBATE_DESCRIBE_TABLE: float = 0.06 +STEP_REBATE_SAMPLE_ROWS: float = 0.05 +STEP_REBATE_RUN_QUERY: float = 0.04 +STEP_REBATE_EXPLAIN_QUERY: float = 0.04 +STEP_REBATE_READ_CHANGELOG: float = 0.08 + +GATE_MALFORMED_TOOL_CALL: float = -0.3 +GATE_CONSECUTIVE_TOOL_ERROR: float = -0.1 +GATE_REPEAT_FAILING_QUERY: float = -0.1 +GATE_BASELINE_VERBATIM: float = -0.2 +_MAX_ESCALATION_STEPS: int = 3 + +CONSULT_ESCALATION: tuple[float, float, float] = (-0.1, -0.3, -0.8) + + +# ============================================================================= +# Helpers +# ============================================================================= + + +def canonicalize_sql(sql: str) -> str: + """Whitespace/case/alias-insensitive canonical form. + + Uses sqlglot's duckdb dialect round-trip so reorders/reformats agree; + falls back to a simple whitespace fold if sqlglot rejects the SQL + (e.g. during the baseline-verbatim check on an agent-submitted blob). + """ + try: + import sqlglot + + expr = sqlglot.parse_one(sql, dialect="duckdb") + return expr.sql(dialect="duckdb", comments=False, normalize=True).strip().lower() + except Exception: + return " ".join(sql.lower().split()) + + +_IDENT_RE = re.compile(r"\b([A-Za-z_][A-Za-z0-9_]*)\b") + + +def _extract_identifiers(sql: str) -> frozenset[str]: + """Loose case-preserving identifier harvest. + + Strips string literals before tokenizing so e.g. `WHERE x = 'ACTIVE'` + does not leak 'ACTIVE' into the identifier set. + """ + stripped = re.sub(r"'[^']*'", "", sql) + return frozenset(m.group(0) for m in _IDENT_RE.finditer(stripped)) + + +def _extract_column_references(sql: str) -> frozenset[str]: + """Column-reference identifiers via sqlglot AST (excludes alias labels). + + For drift-adapt scoring, ``SELECT account_id AS user_id`` references + the new column and merely labels the output — the alias must not + count as a surviving pre-drift marker. Falls back to the regex + extractor on parse failure so malformed SQL still scores something. + """ + try: + import sqlglot + expr = sqlglot.parse_one(sql, dialect="duckdb") + except Exception: + return _extract_identifiers(sql) + if expr is None: + return _extract_identifiers(sql) + return frozenset( + n.name for n in expr.walk() if isinstance(n, sqlglot.exp.Column) and n.name + ) + + +def _literals(sql: str) -> frozenset[str]: + """All `'..'`-quoted string literals in `sql`.""" + return frozenset(re.findall(r"'([^']*)'", sql)) + + +_AGENT_MS_EPSILON: float = 1e-6 + + +def effective_speedup(rt: RuntimeEpisodeState) -> float | None: + """Compute ``effective_speedup`` from the runtime snapshot (speedup rubric). + + Single source of truth for the speedup number used across the code + base — rubric scoring, the skill library's ``avg_speedup`` field, + and the training evaluator all route through here so divergent + definitions cannot drift apart. + + Returns: + + * ``None`` — no submission has happened yet. Callers that need a + numeric default (e.g. the rubric, which is only invoked + post-submission) should verify ``rt.submitted`` first. + * ``+∞`` — drift has fired and the pre-drift baseline SQL no longer + executes against the post-drift schema; any correct submission is + definitionally "infinitely faster" than an unrunnable baseline. + * ``baseline_ms / max(agent_ms, ε)`` otherwise. A tiny ``ε`` clamp + guards against zero/negative timings from sub-microsecond queries + and treats them as "as fast as possible" (very large, finite + speedup) rather than silently collapsing the reward. + """ + if rt.submitted_runtime_ms is None: + return None + if rt.drift_fired and rt.baseline_postdrift_raises: + return math.inf + agent_ms = max(rt.submitted_runtime_ms, _AGENT_MS_EPSILON) + return rt.baseline_runtime_ms / agent_ms + + +def _speedup_for_reward(rt: RuntimeEpisodeState) -> float: + """Rubric-facing speedup that never returns ``None``. + + The rubric is only invoked once ``rt.submitted`` is True, so + :func:`effective_speedup` cannot return ``None`` from these call + sites; we assert that and coerce to ``0.0`` defensively if it ever + does (prevents a silent ``TypeError`` inside the reward math). + """ + val = effective_speedup(rt) + return 0.0 if val is None else val + + +def _is_terminal_submission( + action: SqlDriftAction, + observation: SqlDriftObservation, + rt: RuntimeEpisodeState, +) -> bool: + """True iff this step is the submission step. + + The env sets ``done=True`` on a successful submission and attaches a + :class:`SubmitRewriteResult`; we gate terminal rewards on both + signals so repeated rubric calls on an unchanged state don't + double-score. + """ + if not rt.submitted: + return False + if action.tool != ToolName.SUBMIT_REWRITE: + return False + tr = observation.tool_result + return isinstance(tr, SubmitRewriteResult) + + +def _gt_hash(rt: RuntimeEpisodeState) -> str | None: + if rt.drift_fired and rt.gt_result_hash_postdrift is not None: + return rt.gt_result_hash_postdrift + return rt.gt_result_hash_predrift + + +# ============================================================================= +# Child rubrics +# ============================================================================= + + +class _CtxChild(Rubric): + """Base child rubric sharing the ctx provider.""" + + def __init__(self, ctx_provider: Callable[[], RuntimeEpisodeState]) -> None: + super().__init__() + object.__setattr__(self, "_ctx", ctx_provider) + + def forward( + self, + action: SqlDriftAction, + observation: SqlDriftObservation, + ) -> float: + raise NotImplementedError + + +class Correctness(_CtxChild): + """Terminal-only correctness: +1.0 / +0.5 / -1.0 by hash and speedup.""" + + def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float: + rt = self._ctx() + if not _is_terminal_submission(action, observation, rt): + return 0.0 + gt = _gt_hash(rt) + agent_hash = rt.submitted_result_hash + if gt is None or agent_hash is None: + return 0.0 + if agent_hash != gt: + return -1.0 + speedup = _speedup_for_reward(rt) + if speedup >= SPEEDUP_MIN: + return 1.0 + return 0.5 + + +class DriftAdapt(_CtxChild): + """+0.5 for a correctly-adapted submission, -0.5 for a pre-drift-only + submission after drift fired. + + Adaptation is detected against two scenario-declared identifier sets: + + * ``postdrift_identifiers`` — identifiers/literals that only a + correct post-drift rewrite will introduce (e.g. ``account_id`` + after a column rename, ``'ACTIVE'`` after an enum split). + * ``predrift_identifiers`` — identifiers/literals a submission that + ignored the drift would retain (e.g. ``user_id``, ``'active'``, + the ISO anchor strings under date-format drift). + + A submission is considered "adapted" when it either surfaces a + post-drift marker *or* the scenario declares no distinctive + post-drift identifiers (e.g. date-format drift keeps the same + column name and only the literal shape changes) AND it does not + retain any pre-drift marker. The penalty fires only when the + submission still carries pre-drift markers AND produced the wrong + post-drift result — so a merely partial rewrite (neither pre- + nor post-flavoured) never earns a penalty it can't diagnose. + """ + + def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float: + rt = self._ctx() + # Only drift scenarios participate. + if rt.gt_result_hash_postdrift is None and not rt.drift_fired: + return 0.0 + if not _is_terminal_submission(action, observation, rt): + return 0.0 + inst = getattr(rt, "instance", None) + post_ids: frozenset[str] = ( + getattr(inst, "postdrift_identifiers", frozenset()) or frozenset() + ) + pre_ids: frozenset[str] = getattr(inst, "predrift_identifiers", frozenset()) or frozenset() + agent_sql = rt.submitted_sql or "" + + idents = _extract_column_references(agent_sql) + literals = _literals(agent_sql) + markers = idents | literals + uses_post = bool(post_ids & markers) + uses_pre = bool(pre_ids & markers) + + # Treat "no distinctive post identifier" scenarios as + # satisfied by absence-of-pre (see class docstring). + adapted = (uses_post or not post_ids) and not uses_pre + + agent_hash = rt.submitted_result_hash + gt_post = rt.gt_result_hash_postdrift + + if rt.drift_fired and agent_hash == gt_post and adapted: + return 0.5 + if rt.drift_fired and uses_pre and agent_hash != gt_post: + return -0.5 + return 0.0 + + +class Speedup(_CtxChild): + """Terminal-only, gated on r_correct > 0: 0.3·tanh(log2(speedup)/3).""" + + def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float: + rt = self._ctx() + if not _is_terminal_submission(action, observation, rt): + return 0.0 + gt = _gt_hash(rt) + if gt is None or rt.submitted_result_hash != gt: + return 0.0 + raw = _speedup_for_reward(rt) + if math.isinf(raw): + raw = SPEEDUP_CAP_FOR_INFTY + if raw <= 1.0: + return 0.0 + return 0.3 * math.tanh(math.log2(raw) / 3.0) + + +class StepTax(_CtxChild): + """Base step tax plus bounded rebates for productive exploration.""" + + def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float: + rt = self._ctx() + if _is_terminal_submission(action, observation, rt): + return 0.0 + rebate = max(0.0, float(getattr(rt, "last_step_productive_rebate", 0.0))) + return STEP_TAX + rebate + + +class Gatekeepers(_CtxChild): + """Sum of three independent penalties; repeats escalate up to a cap.""" + + def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float: + rt = self._ctx() + penalty = 0.0 + # 1. Malformed / failed tool call — ToolError emitted this step. + if isinstance(observation.tool_result, ToolError): + penalty += GATE_MALFORMED_TOOL_CALL + streak = max(0, int(getattr(rt, "consecutive_tool_errors", 0)) - 1) + penalty += GATE_CONSECUTIVE_TOOL_ERROR * min(streak, _MAX_ESCALATION_STEPS) + # 2. Repeat failing query — env marks the flag on the runtime + # state immediately before invoking the rubric. + repeats = max(0, int(getattr(rt, "last_step_repeat_failing_query_count", 0)) - 1) + if repeats > 0: + penalty += GATE_REPEAT_FAILING_QUERY * min(repeats, _MAX_ESCALATION_STEPS) + # 3. Baseline-verbatim submission (Rev-3 gate — stacks with + # correctness's +0.5 partial to cap the no-op rewrite at +0.3). + if ( + action.tool == ToolName.SUBMIT_REWRITE + and _is_terminal_submission(action, observation, rt) + and rt.submitted_sql_canonical == rt.baseline_sql_canonical + ): + penalty += GATE_BASELINE_VERBATIM + return penalty + + +class ConsultDBA(_CtxChild): + """Escalating penalties -0.1 / -0.3 / -0.8 per consult when the flag is on.""" + + def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float: + rt = self._ctx() + oracle_enabled = getattr(rt, "dba_oracle_enabled", False) + if not oracle_enabled: + return 0.0 + if action.tool != ToolName.CONSULT_DBA: + return 0.0 + # Count the consult THIS step by indexing into the escalation + # table using the pre-increment value (env increments on the same step). + tier = min(rt.consultations_used, len(CONSULT_ESCALATION)) + if tier <= 0: + return CONSULT_ESCALATION[0] + return CONSULT_ESCALATION[tier - 1] + + +# ============================================================================= +# Composite +# ============================================================================= + + +class SqlDriftRubric(Rubric): + """Composite rubric: sum of six children. + + Registration as attributes auto-enrolls them in + :meth:`Rubric.named_rubrics` so training loops can introspect + per-component scores. + """ + + def __init__(self, ctx_provider: Callable[[], RuntimeEpisodeState]) -> None: + super().__init__() + # NOTE: order matters — correctness must populate last_score before + # speedup reads it via the shared ctx_provider (both are pure + # functions of the runtime state, so identical output — but the + # explicit ordering documents the intent). + self.correctness = Correctness(ctx_provider) + self.drift_adapt = DriftAdapt(ctx_provider) + self.speedup = Speedup(ctx_provider) + self.step_tax = StepTax(ctx_provider) + self.gatekeepers = Gatekeepers(ctx_provider) + self.consult_dba = ConsultDBA(ctx_provider) + + def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float: + total = ( + self.correctness(action, observation) + + self.drift_adapt(action, observation) + + self.speedup(action, observation) + + self.step_tax(action, observation) + + self.gatekeepers(action, observation) + + self.consult_dba(action, observation) + ) + return total + + def component_scores(self) -> dict[str, float]: + """Return the most-recent per-component scores, keyed for W&B. + + Keys match :data:`models.REWARD_COMPONENT_KEYS` so the observation + envelope and the demo plots agree on a stable schema. + """ + return { + "r_correct": float(self.correctness.last_score or 0.0), + "r_drift": float(self.drift_adapt.last_score or 0.0), + "r_speedup": float(self.speedup.last_score or 0.0), + "r_step_tax": float(self.step_tax.last_score or 0.0), + "r_gatekeepers": float(self.gatekeepers.last_score or 0.0), + "r_consult_dba": float(self.consult_dba.last_score or 0.0), + } + + +__all__ = [ + "CONSULT_ESCALATION", + "ConsultDBA", + "Correctness", + "DriftAdapt", + "GATE_BASELINE_VERBATIM", + "GATE_CONSECUTIVE_TOOL_ERROR", + "GATE_MALFORMED_TOOL_CALL", + "GATE_REPEAT_FAILING_QUERY", + "Gatekeepers", + "SPEEDUP_CAP_FOR_INFTY", + "SPEEDUP_MIN", + "STEP_REBATE_DESCRIBE_TABLE", + "STEP_REBATE_EXPLAIN_QUERY", + "STEP_REBATE_LIST_TABLES", + "STEP_REBATE_READ_CHANGELOG", + "STEP_REBATE_RUN_QUERY", + "STEP_REBATE_SAMPLE_ROWS", + "STEP_TAX", + "Speedup", + "SqlDriftRubric", + "StepTax", + "canonicalize_sql", + "effective_speedup", +] diff --git a/engine/runtime.py b/engine/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..983e70127eff749ad15236e6eb7356d5d632b2b3 --- /dev/null +++ b/engine/runtime.py @@ -0,0 +1,101 @@ +"""Private per-episode runtime state (not exposed over the wire). + +This module is imported by both :mod:`engine.reward` and +:mod:`server.sql_drift_env_environment` — keeping it out of ``server/`` +avoids the import cycle ``engine.reward → server → engine.reward``. + +NEVER serialize or expose this over any endpoint. The public state +projection lives in :class:`models.SqlDriftState`. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from models import EpisodePhase + +if TYPE_CHECKING: + import duckdb + + +@dataclass +class RuntimeEpisodeState: + """Private per-episode state — holds secrets (DB handle, ground truth).""" + + episode_id: str + seed: int + scenario_id: str + + instance: Any # scenarios.base.ScenarioInstance — forward ref avoids import cycle + conn: duckdb.DuckDBPyConnection + + # Ground truth hashes (never exposed) + gt_result_hash_predrift: str + gt_result_hash_postdrift: str | None + + # Baseline measurements (used by speedup + baseline-verbatim gates) + baseline_runtime_ms: float + baseline_tokens: int + baseline_sql_canonical: str + baseline_postdrift_raises: bool + + # Drift timing (scheduled step + cooldown relative to first run_query). + drift_scheduled_step: int | None + connection_poisoned: bool = False + drift_fired_step: int | None = None + first_run_query_step: int | None = None + + # Episode progression + step_count: int = 0 + phase: EpisodePhase = EpisodePhase.DIAGNOSE + budget_steps: int = 25 + + # Per-step bookkeeping + failed_query_hashes: set[str] = field(default_factory=set) + failed_query_counts: dict[str, int] = field(default_factory=dict) + changelog_entries: list[str] = field(default_factory=list) + consultations_used: int = 0 + listed_tables_rewarded: bool = False + described_tables_rewarded: set[str] = field(default_factory=set) + sampled_tables_rewarded: set[str] = field(default_factory=set) + run_query_rewarded: bool = False + explain_query_rewarded: bool = False + changelog_rewarded_after_drift: bool = False + + # Phase-machine bookkeeping — counts successful diagnostic tool calls + # (list_tables, describe_table, sample_rows, run_query, explain_query, + # read_changelog). The DIAGNOSE → REWRITE transition fires the first + # time this becomes non-zero; SUBMIT_REWRITE is rejected while this is + # still zero (ToolErrorCode.SUBMIT_BEFORE_DIAGNOSE). + diagnostic_actions_taken: int = 0 + + # Submission state — populated once SUBMIT_REWRITE is accepted + submitted: bool = False + submitted_sql: str | None = None + submitted_sql_canonical: str | None = None + submitted_result_hash: str | None = None + submitted_runtime_ms: float | None = None + + # Last-step signal — consumed by the rubric to compute per-step penalties + last_step_was_tool_error: bool = False + last_step_was_repeat_failing_query: bool = False + last_step_repeat_failing_query_count: int = 0 + last_step_productive_rebate: float = 0.0 + consecutive_tool_errors: int = 0 + + # Drift acknowledgement — set True the first time the agent reads the + # changelog or observes post-drift schema identifiers in a query. + drift_acknowledged: bool = False + + # DBA Oracle feature flag (read by the ConsultDBA child rubric). + # Always False unless explicitly enabled at reset (kwarg or env var). + dba_oracle_enabled: bool = False + + @property + def drift_fired(self) -> bool: + return self.drift_fired_step is not None + + @property + def budget_steps_remaining(self) -> int: + return max(0, self.budget_steps - self.step_count) diff --git a/engine/verifier.py b/engine/verifier.py new file mode 100644 index 0000000000000000000000000000000000000000..45d4120fe2c1c1d5b338323b971e114e9cc3e669 --- /dev/null +++ b/engine/verifier.py @@ -0,0 +1,101 @@ +"""Deterministic result verification. + +Canonicalizes floats to `FLOAT_PRECISION` decimal places and treats NULL +uniformly so that two result sets with the same semantic content hash to +the same digest regardless of row order, floating-point noise, or None vs +SQL NULL representation. +""" + +from __future__ import annotations + +import hashlib +from collections.abc import Iterable +from typing import Any + +NULL_SENTINEL = "\x00NULL\x00" +FLOAT_PRECISION = 6 +_DIGEST_MODULUS = 1 << 256 + + +def _normalize_value(v: Any) -> Any: + if v is None: + return NULL_SENTINEL + if isinstance(v, float): + # NaN hashes to itself here — a NaN in rows yields a deterministic + # digest, but two NaNs round-trip to different representations when + # we `repr`. Guard explicitly. + if v != v: # NaN + return NULL_SENTINEL + return round(v, FLOAT_PRECISION) + return v + + +def _row_digest_int(row: Iterable[Any]) -> int: + normalized = tuple(_normalize_value(v) for v in row) + digest = hashlib.sha256(repr(normalized).encode()).digest() + return int.from_bytes(digest, "big", signed=False) + + +def canonical_row_hash(rows: Iterable[Iterable[Any]]) -> str: + """Order-independent hash of a result set. + + This stays order-independent and duplicate-sensitive without + materializing the full result in memory. Each normalized row is + hashed once, then folded into three commutative accumulators so the + final digest is stable across row order and Python processes. + """ + row_count = 0 + sum_acc = 0 + sumsq_acc = 0 + xor_acc = 0 + for row in rows: + row_count += 1 + row_digest = _row_digest_int(row) + sum_acc = (sum_acc + row_digest) % _DIGEST_MODULUS + sumsq_acc = (sumsq_acc + ((row_digest * row_digest) % _DIGEST_MODULUS)) % _DIGEST_MODULUS + xor_acc ^= row_digest + payload = b"".join( + ( + row_count.to_bytes(32, "big", signed=False), + sum_acc.to_bytes(32, "big", signed=False), + sumsq_acc.to_bytes(32, "big", signed=False), + xor_acc.to_bytes(32, "big", signed=False), + ) + ) + return hashlib.sha256(payload).hexdigest() + + +def result_matches(agent_rows: Iterable[Iterable[Any]], gt_hash: str) -> bool: + """True if `agent_rows` canonicalizes to the ground-truth hash.""" + return canonical_row_hash(agent_rows) == gt_hash + + +def row_set_jaccard(a: Iterable[Iterable[Any]], b: Iterable[Iterable[Any]]) -> float: + """Jaccard over normalized row sets (order- and duplicate-insensitive). + + Each input row is normalised with :func:`_normalize_value` and + collapsed into a :class:`frozenset`-style Python ``set``, so rows + that repeat within a single result are counted once. This is + deliberately *not* a multiset Jaccard — multiset semantics would + punish correct queries that legitimately emit duplicates more + harshly than intended. + + Not used by the lean reward today, but kept covered by tests so + we can opt in later without rework. + """ + norm_a = {tuple(_normalize_value(v) for v in row) for row in a} + norm_b = {tuple(_normalize_value(v) for v in row) for row in b} + if not norm_a and not norm_b: + return 1.0 + union = norm_a | norm_b + inter = norm_a & norm_b + return len(inter) / len(union) + + +__all__ = [ + "FLOAT_PRECISION", + "NULL_SENTINEL", + "canonical_row_hash", + "result_matches", + "row_set_jaccard", +] diff --git a/models.py b/models.py new file mode 100644 index 0000000000000000000000000000000000000000..12b268b230524fe3d9855fb54e95ff70edf6f668 --- /dev/null +++ b/models.py @@ -0,0 +1,364 @@ +"""Public data models for SQLDrift. + +Rev 3 design notes enforced here: + +- Action is a discriminated union over a public `kind: Literal[...]` tag on + each payload sub-model. Pydantic v2 forbids leading-underscore names as + discriminator keys (reserved for private attrs), so we keep the tag public. +- `SqlDriftAction` cross-validates that the envelope-level `tool` matches + `payload.kind` (prevents inconsistent envelopes from being constructed). +- `SqlDriftObservation.tool_result` is itself a discriminated union over the + eight concrete result types plus `ToolError` (for in-env semantic failures; + envelope-level `ValidationError` is a transport-layer concern, not an in-env code). +- `SqlDriftState` is the public state snapshot shipped over `/state`. It + never carries ground truth, DB handles, baseline runtime, or seeds; + `extra="forbid"` guarantees no accidental leak as new fields are added. + The private `RuntimeEpisodeState` lives in :mod:`engine.runtime`. +""" + +from __future__ import annotations + +from enum import StrEnum +from typing import Annotated, Any, Literal + +from openenv.core.env_server.types import Action, Observation, State +from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic_core import PydanticCustomError + +# ============================================================================= +# Enums +# ============================================================================= + + +class ToolName(StrEnum): + LIST_TABLES = "list_tables" + DESCRIBE_TABLE = "describe_table" + SAMPLE_ROWS = "sample_rows" + RUN_QUERY = "run_query" + EXPLAIN_QUERY = "explain_query" + READ_CHANGELOG = "read_changelog" + SUBMIT_REWRITE = "submit_rewrite" + CONSULT_DBA = "consult_dba" + + +class EpisodePhase(StrEnum): + DIAGNOSE = "diagnose" + REWRITE = "rewrite" + DRIFT_RECOVERY = "drift_recovery" + FINALIZE = "finalize" + + +class ToolErrorCode(StrEnum): + """In-environment semantic failure codes (API contract). + + Envelope-level `pydantic.ValidationError` is handled by the OpenEnv + transport layer (HTTP 422 / `/ws` error frame) and never reaches + `env.step`, so it has no code here. + """ + + DB_ERROR = "db_error" + UNKNOWN_TABLE = "unknown_table" + QUERY_TIMEOUT = "query_timeout" + RESULT_TOO_LARGE = "result_too_large" + SUBMIT_BEFORE_DIAGNOSE = "submit_before_diagnose" + INVALID_TOOL_ARGUMENT = "invalid_tool_argument" + + +# ============================================================================= +# Tool payloads (request side of `SqlDriftAction`) +# ============================================================================= + + +class _BasePayload(BaseModel): + """Shared config for every tool-call payload.""" + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + +class ListTablesPayload(_BasePayload): + kind: Literal["list_tables"] = "list_tables" + + +class DescribeTablePayload(_BasePayload): + kind: Literal["describe_table"] = "describe_table" + table: str = Field(min_length=1, max_length=63) + + +class SampleRowsPayload(_BasePayload): + kind: Literal["sample_rows"] = "sample_rows" + table: str = Field(min_length=1, max_length=63) + limit: int = Field(default=5, ge=1, le=5) + + +class RunQueryPayload(_BasePayload): + kind: Literal["run_query"] = "run_query" + sql: str = Field(min_length=1, max_length=10_000) + + +class ExplainQueryPayload(_BasePayload): + kind: Literal["explain_query"] = "explain_query" + sql: str = Field(min_length=1, max_length=10_000) + + +class ReadChangelogPayload(_BasePayload): + kind: Literal["read_changelog"] = "read_changelog" + + +class SubmitRewritePayload(_BasePayload): + kind: Literal["submit_rewrite"] = "submit_rewrite" + sql: str = Field(min_length=1, max_length=10_000) + + +class ConsultDBAPayload(_BasePayload): + kind: Literal["consult_dba"] = "consult_dba" + question: str = Field(min_length=1, max_length=400) + + +ToolPayload = Annotated[ + ListTablesPayload + | DescribeTablePayload + | SampleRowsPayload + | RunQueryPayload + | ExplainQueryPayload + | ReadChangelogPayload + | SubmitRewritePayload + | ConsultDBAPayload, + Field(discriminator="kind"), +] + + +# Tool -> payload-kind mapping; single source of truth for cross-validation +# and for the server-side dispatcher in P7. +TOOL_TO_PAYLOAD_KIND: dict[ToolName, str] = { + ToolName.LIST_TABLES: "list_tables", + ToolName.DESCRIBE_TABLE: "describe_table", + ToolName.SAMPLE_ROWS: "sample_rows", + ToolName.RUN_QUERY: "run_query", + ToolName.EXPLAIN_QUERY: "explain_query", + ToolName.READ_CHANGELOG: "read_changelog", + ToolName.SUBMIT_REWRITE: "submit_rewrite", + ToolName.CONSULT_DBA: "consult_dba", +} + + +# ============================================================================= +# SqlDriftAction envelope +# ============================================================================= + + +class SqlDriftAction(Action): + """Tool-call envelope. + + JSON wire format:: + + {"tool": "run_query", "payload": {"kind": "run_query", "sql": "..."}} + + The `tool` field and `payload.kind` must agree; mismatch raises at + validation time. + """ + + tool: ToolName + payload: ToolPayload + + @model_validator(mode="after") + def _tool_matches_payload(self) -> SqlDriftAction: + expected = TOOL_TO_PAYLOAD_KIND[self.tool] + if self.payload.kind != expected: + # PydanticCustomError keeps ``ctx`` JSON-serializable (plain + # strings only), unlike a bare ``ValueError`` which Pydantic + # wraps with ``ctx={"error": ValueError(...)}`` and breaks + # FastAPI HTTPException JSON encoder (422 responses). + raise PydanticCustomError( + "tool_payload_mismatch", + "tool/payload mismatch: tool={tool} expects payload.kind={expected}, got {got}", + { + "tool": self.tool.value, + "expected": expected, + "got": self.payload.kind, + }, + ) + return self + + +# ============================================================================= +# Tool results (response side of `SqlDriftObservation.tool_result`) +# ============================================================================= + + +class _BaseResult(BaseModel): + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + +class ListTablesResult(_BaseResult): + kind: Literal["list_tables_result"] = "list_tables_result" + tables: list[str] + + +class DescribeTableResult(_BaseResult): + kind: Literal["describe_table_result"] = "describe_table_result" + table: str + columns: list[dict[str, str]] # [{"name": "...", "type": "..."}] + + +class SampleRowsResult(_BaseResult): + kind: Literal["sample_rows_result"] = "sample_rows_result" + table: str + columns: list[str] + rows: list[list[Any]] + + +class RunQueryResult(_BaseResult): + kind: Literal["run_query_result"] = "run_query_result" + columns: list[str] + rows: list[list[Any]] + runtime_ms: float + row_count: int + + +class ExplainQueryResult(_BaseResult): + kind: Literal["explain_query_result"] = "explain_query_result" + plan: str + + +class ReadChangelogResult(_BaseResult): + kind: Literal["read_changelog_result"] = "read_changelog_result" + entries: list[str] + + +class SubmitRewriteResult(_BaseResult): + kind: Literal["submit_rewrite_result"] = "submit_rewrite_result" + accepted: bool + runtime_ms: float + matches_ground_truth: bool + + +class ConsultDBAResult(_BaseResult): + kind: Literal["consult_dba_result"] = "consult_dba_result" + tier: int = Field(ge=1, le=3) + hint: str + + +class ToolError(_BaseResult): + kind: Literal["tool_error"] = "tool_error" + code: ToolErrorCode + message: str = Field(max_length=2_000) + + +ToolResult = Annotated[ + ListTablesResult + | DescribeTableResult + | SampleRowsResult + | RunQueryResult + | ExplainQueryResult + | ReadChangelogResult + | SubmitRewriteResult + | ConsultDBAResult + | ToolError, + Field(discriminator="kind"), +] + + +# The six reward-component keys match the composed rubric; tests and telemetry +# rely on this exact schema. +REWARD_COMPONENT_KEYS: tuple[str, ...] = ( + "r_correct", + "r_drift", + "r_speedup", + "r_step_tax", + "r_gatekeepers", + "r_consult_dba", +) + + +# ============================================================================= +# SqlDriftObservation +# ============================================================================= + + +def _zero_reward_components() -> dict[str, float]: + """Six-key reward envelope initialised to zero. + + Every observation, including the reset observation, carries the full + six-key schema so telemetry and tests can index it unconditionally. + """ + return {key: 0.0 for key in REWARD_COMPONENT_KEYS} + + +class SqlDriftObservation(Observation): + """Observation returned by :meth:`SqlDriftEnvironment.step`. + + Inherits `done: bool` and `reward: float | None` from base Observation. + + The task payload (`baseline_sql`, `schema_synopsis`) is delivered on + the reset observation and kept empty on subsequent steps: the agent + is expected to capture it once and hold it in its own context. + """ + + step: int = Field(ge=0) + phase: EpisodePhase + last_tool: ToolName | None = None + tool_result: ToolResult | None = None + drift_fired: bool = False + drift_acknowledged: bool = False + learned_hints: str = Field(default="", max_length=800) + baseline_sql: str = Field(default="", max_length=10_000) + schema_synopsis: str = Field(default="", max_length=2_000) + budget_steps_remaining: int = Field(ge=0) + reward_components: dict[str, float] = Field(default_factory=_zero_reward_components) + + +# ============================================================================= +# SqlDriftState — PUBLIC state (sanitized) +# ============================================================================= + + +class SqlDriftState(State): + """Public state snapshot — serialized over `/state`. + + Ground truth, DB handles, seeds, and baseline SQL live in + :class:`engine.runtime.RuntimeEpisodeState` and are never exposed here. + `extra="forbid"` guarantees no accidental leak via future field additions. + """ + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + scenario_id: str + phase: EpisodePhase + budget_steps_remaining: int = Field(ge=0) + drift_fired: bool = False + consultations_used: int = Field(default=0, ge=0) + submitted: bool = False + + +__all__ = [ + "ConsultDBAPayload", + "ConsultDBAResult", + "DescribeTablePayload", + "DescribeTableResult", + "EpisodePhase", + "ExplainQueryPayload", + "ExplainQueryResult", + "ListTablesPayload", + "ListTablesResult", + "REWARD_COMPONENT_KEYS", + "ReadChangelogPayload", + "ReadChangelogResult", + "RunQueryPayload", + "RunQueryResult", + "SampleRowsPayload", + "SampleRowsResult", + "SqlDriftAction", + "SqlDriftObservation", + "SqlDriftState", + "SubmitRewritePayload", + "SubmitRewriteResult", + "TOOL_TO_PAYLOAD_KIND", + "ToolError", + "ToolErrorCode", + "ToolName", + "ToolPayload", + "ToolResult", +] diff --git a/openenv.yaml b/openenv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..45a9dd4f81408d3c6c1767d551a1755bde61b6da --- /dev/null +++ b/openenv.yaml @@ -0,0 +1,7 @@ +spec_version: 1 +name: sql_drift_env +type: space +runtime: fastapi +app: server.app:app +port: 8000 + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..6a33d05032aabe7d718922017d5f679793b00f06 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +[build-system] +requires = ["setuptools>=45", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "sql_drift_env" +version = "0.1.0" +description = "SQLDrift: OpenEnv gym for repairing and optimizing SQL under live schema/business-rule drift" +requires-python = ">=3.12,<3.14" +dependencies = [ + "duckdb>=1.5.2,<2.0", + "huggingface-hub>=0.20,<1.0", + "openenv-core[core]>=0.2.2,<0.4", + "sqlglot>=30.6.0,<40.0", + "pydantic>=2.8.0,<3.0", + "python-dotenv>=1.2.2,<2.0", + "openai>=2.32.0,<3.0", +] + +[project.optional-dependencies] +evidence = [ + "matplotlib>=3.8.0,<4.0", + "pandas>=2.0.0,<3.0", +] +dev = [ + "mypy>=1.20.1", + "pytest>=9.0.3", + "pytest-asyncio>=1.3.0", + "pytest-cov>=7.0.0", + "ruff>=0.15.11", + "httpx>=0.28.0", +] +train = [ + "trl>=0.25.0,<1.0", + "datasets>=2.20.0,<4.0", + "transformers>=4.46.0,<5.0", + "accelerate>=1.0.0,<2.0", + "peft>=0.13,<1.0", + "bitsandbytes>=0.43,<1.0", + "unsloth>=2024.9,<2027.0", +] + +[project.scripts] +# Flat package layout; see _cli.py. Runtime path: site-packages/sql_drift_env/ +server = "sql_drift_env._cli:main" + +[tool.setuptools] +include-package-data = true +packages = [ + "sql_drift_env", + "sql_drift_env.scenarios", + "sql_drift_env.engine", + "sql_drift_env.skill_library", + "sql_drift_env.actors", + "sql_drift_env.training", + "sql_drift_env.server", + "sql_drift_env.utilities", +] + +[tool.setuptools.package-dir] +"sql_drift_env" = "." +"sql_drift_env.scenarios" = "scenarios" +"sql_drift_env.engine" = "engine" +"sql_drift_env.skill_library" = "skill_library" +"sql_drift_env.actors" = "actors" +"sql_drift_env.training" = "training" +"sql_drift_env.server" = "server" +"sql_drift_env.utilities" = "utilities" + +[tool.pytest.ini_options] +testpaths = ["tests"] +asyncio_mode = "auto" +markers = [ + "slow: marks tests as slow (excluded by default; run via `pytest -m slow`)", +] +# Default excludes slow tests so CI stays on a fast path; run `pytest -m slow` for full suite. +addopts = "-ra --strict-markers -m 'not slow'" + +[tool.ruff] +line-length = 100 +target-version = "py312" +extend-exclude = [".venv", "design", "*.egg-info", "*.ipynb"] + +[tool.ruff.lint] +select = ["E", "F", "I", "B", "UP", "SIM"] +ignore = ["E501"] + +# Prefer fixing structure over scattered inline suppressions. See each rule in Ruff docs. +[tool.ruff.lint.per-file-ignores] +"_cli.py" = ["PLC0415"] +"training/grpo_train.py" = ["PLC0415"] +"engine/profiler.py" = ["BLE001"] +"scenarios/__init__.py" = ["F401"] +"tests/unit/test_p0_smoke.py" = ["F401"] +"utilities/verbose_api_rollout.py" = ["E402"] +"utilities/demo_rollout.py" = ["E402"] +"tests/unit/test_drift_scenarios.py" = ["E402"] +"tests/unit/test_profiler.py" = ["E402"] +"tests/unit/test_drift.py" = ["E402"] +"tests/unit/test_runtime.py" = ["E402"] + +[tool.mypy] +python_version = "3.12" +strict = true +ignore_missing_imports = true +explicit_package_bases = true +mypy_path = "." +files = [ + "_cli.py", + "actors", + "client.py", + "engine", + "models.py", + "scenarios", + "utilities", + "server", + "skill_library", + "training", +] +exclude = [ + "\\.venv", + "design", + ".*\\.egg-info", + "tests", + "build", + "dist", +] + +# Subclasses in these modules call into openenv-core (no stubs). Relax only here. +[[tool.mypy.overrides]] +module = [ + "client", + "models", + "engine.reward", + "server.app", + "server.sql_drift_env_environment", + "training.grpo_train", +] +disallow_subclassing_any = false +warn_return_any = false + +[tool.coverage.run] +# pytest-cov can hit upstream issues with beartype when tracing; run plain pytest for CI. +source = ["engine", "scenarios", "skill_library", "actors", "models.py"] +omit = ["tests/*", ".venv/*"] + +[tool.coverage.report] +skip_empty = true +exclude_lines = [ + "pragma: no cover", + "if __name__ == \"__main__\":", + "if TYPE_CHECKING:", + "raise NotImplementedError", + "if 0:", + "if False:", + "\\.\\.\\.", +] diff --git a/scenarios/__init__.py b/scenarios/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ffd54f567a8b1c3f81f6b6ee77f007700d0073a --- /dev/null +++ b/scenarios/__init__.py @@ -0,0 +1,72 @@ +"""Scenario registry + auto-discovery. + +Every sibling module whose name starts with ``scenario_`` is expected to +export a module-level ``SPEC: ScenarioSpec``. This module walks the package +at import time and populates :data:`REGISTRY` so callers (env reset, tests, +CLI eval) can look scenarios up by id without knowing the file layout. +""" + +from __future__ import annotations + +import importlib +import pkgutil + +from .base import ( + DriftConfig, + DriftKind, + Family, + ScenarioInstance, + ScenarioSpec, +) + +REGISTRY: dict[str, ScenarioSpec] = {} + + +def _discover_specs() -> None: + """Import every `scenario_*` sibling module and harvest their SPEC.""" + package = __name__ # "scenarios" + package_path = __path__ # provided by Python's package machinery + + for info in pkgutil.iter_modules(package_path): + if not info.name.startswith("scenario_"): + continue + module = importlib.import_module(f"{package}.{info.name}") + spec = getattr(module, "SPEC", None) + if spec is None: + raise RuntimeError(f"{package}.{info.name} is missing a module-level `SPEC` export") + if not isinstance(spec, ScenarioSpec): + raise TypeError( + f"{package}.{info.name}.SPEC is {type(spec).__name__}; expected ScenarioSpec" + ) + if spec.scenario_id in REGISTRY: + raise RuntimeError( + f"duplicate scenario_id {spec.scenario_id!r} — " + f"already registered from {REGISTRY[spec.scenario_id]!r}" + ) + REGISTRY[spec.scenario_id] = spec + + +def iter_specs() -> list[ScenarioSpec]: + return sorted(REGISTRY.values(), key=lambda s: s.scenario_id) + + +def get_spec(scenario_id: str) -> ScenarioSpec: + try: + return REGISTRY[scenario_id] + except KeyError as e: + raise KeyError(f"unknown scenario_id={scenario_id!r}; known: {sorted(REGISTRY)}") from e + + +_discover_specs() + + +__all__ = [ + "DriftConfig", + "DriftKind", + "Family", + "REGISTRY", + "ScenarioInstance", + "ScenarioSpec", + "get_spec", + "iter_specs", +] diff --git a/scenarios/_fixtures.py b/scenarios/_fixtures.py new file mode 100644 index 0000000000000000000000000000000000000000..dd5f4f8f675b059bb54b1048d4531a9bd05c7e6d --- /dev/null +++ b/scenarios/_fixtures.py @@ -0,0 +1,119 @@ +"""Deterministic fixture generators (stdlib only). + +All distributions are backed by :class:`random.Random(seed)` so a given +``(scenario_id, seed, scale)`` tuple always yields the same table contents. + +Exposes a single :func:`seeded_rng` factory plus a handful of domain-specific +generators used by the concrete scenarios. No numpy/pandas runtime deps. +""" + +from __future__ import annotations + +import math +import random +from collections.abc import Sequence +from datetime import UTC + + +def seeded_rng(*parts: int | str) -> random.Random: + """Derive a deterministic `random.Random` from mixed scalar parts. + + Uses a stable 64-bit SplitMix-style hash over ``repr(parts)`` — avoids + Python's per-interpreter salted ``hash()`` for str. + """ + h = 1469598103934665603 # FNV-1a 64-bit offset basis + for p in parts: + for byte in repr(p).encode(): + h ^= byte + h = (h * 1099511628211) & 0xFFFFFFFFFFFFFFFF + return random.Random(h) + + +def lognormal_amounts(rng: random.Random, n: int, mu: float, sigma: float) -> list[float]: + """n log-normally distributed positive amounts, rounded to 2dp.""" + return [round(rng.lognormvariate(mu, sigma), 2) for _ in range(n)] + + +def zipfian_choices( + rng: random.Random, population: Sequence[int], n: int, *, alpha: float = 1.07 +) -> list[int]: + """n draws from `population` with zipf(alpha) weights. + + `alpha` controls skew: 1.07 is the project default. Uses + `random.choices` with explicit weights, so the draws are stable across + Python versions for a given `rng` state. + """ + weights = [1.0 / ((i + 1) ** alpha) for i in range(len(population))] + return rng.choices(list(population), weights=weights, k=n) + + +def date_range_epoch_ms( + rng: random.Random, + n: int, + *, + start_epoch_ms: int, + window_days: int, +) -> list[int]: + """n random timestamps (ms) within `[start, start + window_days)`.""" + span_ms = window_days * 86_400_000 + return [start_epoch_ms + rng.randrange(span_ms) for _ in range(n)] + + +def iso_strings_from_epoch_ms(epoch_ms: list[int]) -> list[str]: + """Convert epoch ms to ISO-8601 UTC strings (matching DuckDB's native coerce).""" + from datetime import datetime + + return [ + datetime.fromtimestamp(t / 1000, tz=UTC).isoformat().replace("+00:00", "Z") + for t in epoch_ms + ] + + +def categorical_choices( + rng: random.Random, + categories: Sequence[str], + n: int, + *, + weights: Sequence[float] | None = None, +) -> list[str]: + return rng.choices(list(categories), weights=list(weights) if weights else None, k=n) + + +def unique_names(rng: random.Random, n: int, *, prefix: str = "name") -> list[str]: + """Stable pseudo-unique string IDs of the form `_<64-bit-hex>`.""" + return [f"{prefix}_{rng.getrandbits(64):016x}" for _ in range(n)] + + +def approx_normal( + rng: random.Random, n: int, *, mu: float, sigma: float, clip_lo: float | None = None +) -> list[float]: + """n normal draws, optionally clipped below.""" + out: list[float] = [] + for _ in range(n): + x = rng.gauss(mu, sigma) + if clip_lo is not None and x < clip_lo: + x = clip_lo + out.append(round(x, 4)) + return out + + +def sanity_nonzero_variance(xs: Sequence[float]) -> bool: + """Guard: reject obviously degenerate distributions (used in smoke tests).""" + if not xs: + return False + mean = sum(xs) / len(xs) + var = sum((x - mean) ** 2 for x in xs) / len(xs) + return var > 1e-12 and not math.isnan(var) + + +__all__ = [ + "approx_normal", + "categorical_choices", + "date_range_epoch_ms", + "iso_strings_from_epoch_ms", + "lognormal_amounts", + "sanity_nonzero_variance", + "seeded_rng", + "unique_names", + "zipfian_choices", +] diff --git a/scenarios/base.py b/scenarios/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ea4b8c5f9f2492dd8597102417ffd190b0f825 --- /dev/null +++ b/scenarios/base.py @@ -0,0 +1,230 @@ +"""Scenario spec + instance primitives. + +Each concrete scenario file in :mod:`scenarios` exports: + +- ``SPEC: ScenarioSpec`` — the immutable metadata (id, family, tags, + optional drift config) plus a bound ``builder`` callable. + +The builder takes ``(spec, seed, scale)`` and returns a ready-to-attach +:class:`ScenarioInstance` whose DuckDB connection has been loaded with +deterministic fixtures, ground-truth hashes pre-computed, and baseline +runtime measured. ``base_scale`` is author-tuned per scenario so the +measured baseline clears :data:`BASELINE_MIN_MS` on a single build — +the old timing-driven reroll loop was removed because it coupled the +fixture RNG seed to the retry count, which destroyed determinism +whenever CI hit a jitter-induced retry. +""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal + +if TYPE_CHECKING: + import duckdb + +from utilities.logger import get_module_logger + +_LOG = get_module_logger(__name__) + +Family = Literal["ecommerce", "events", "cms", "saas_logs", "multitenant"] +DriftKind = Literal["column_rename", "date_format", "enum_rule", "field_deprecation"] +DifficultyLevel = Literal["easy", "normal", "hard"] + +# Baseline floor — empirically calibrated, not a fixed marketing target +# 50 ms. DuckDB on in-memory fixtures at CI-reasonable scales (a few +# tens of thousands of rows) measures warm baselines of 0.4–2 ms on +# the anti-pattern queries shipped here; reaching 50 ms would require +# multi-minute fixture builds per scenario, which is untenable for +# both CI and RL rollouts (every reset rebuilds). +# +# 0.3 ms is ~3–5× the median-of-3 warm jitter floor on a quiet CPU +# (observed jitter ~60–100 µs). This SNR is tight but workable because +# the rubric gates the speedup reward at 1.2× before any +# credit is issued, so jitter-induced near-1× "speedups" score zero. +# The cap at 64× bounds upside. A 2× rewrite against a 0.3 ms baseline +# lands at 0.15 ms — still distinguishable from jitter under +# median-of-3 smoothing. +# +# The same floor applies in production and CI — no env-var escape +# hatch — so tests exercise the real reward distribution. Per-scenario +# overrides may raise *or* lower this floor when a scenario's query +# shape has a different natural baseline (see the field docstring on +# :class:`ScenarioSpec.baseline_min_ms`). +BASELINE_MIN_MS = 0.3 + + +@dataclass(frozen=True) +class DriftConfig: + kind: DriftKind + payload: dict[str, Any] + min_step: int = 6 + max_step: int = 12 + cooldown_steps: int = 2 + + def __post_init__(self) -> None: + if self.min_step < 1: + raise ValueError("min_step must be >= 1") + if self.max_step < self.min_step: + raise ValueError("max_step must be >= min_step") + if self.cooldown_steps < 0: + raise ValueError("cooldown_steps must be >= 0") + + +@dataclass +class ScenarioInstance: + """Concretized scenario — ready-to-attach DuckDB fixture + ground truths.""" + + conn: duckdb.DuckDBPyConnection + baseline_sql: str + gt_sql_predrift: str + gt_sql_postdrift: str | None + baseline_runtime_ms: float + baseline_tokens: int + gt_result_hash_predrift: str + gt_result_hash_postdrift: str | None + drift_config: DriftConfig | None + schema_synopsis: str + # Drift-distinctive identifier sets consumed by the drift-adapt + # rubric. ``postdrift_identifiers`` marks identifiers/literals + # the correct post-drift rewrite MUST introduce; ``predrift_identifiers`` + # marks identifiers/literals a submission that ignored the drift + # WOULD retain. Together they let the rubric distinguish "adapted" + # from "did not adapt" for drift kinds where a single identifier + # (e.g. ``ts`` under date-format drift) is shared by both sides. + postdrift_identifiers: frozenset[str] = field(default_factory=frozenset) + predrift_identifiers: frozenset[str] = field(default_factory=frozenset) + + +# Builder signature: (spec, seed, scale) -> (conn, baseline_sql, +# gt_sql_predrift, gt_sql_postdrift, schema_synopsis, +# postdrift_identifiers, predrift_identifiers). +BuilderResult = tuple[ + "duckdb.DuckDBPyConnection", + str, # baseline_sql + str, # gt_sql_predrift + str | None, # gt_sql_postdrift + str, # schema_synopsis + frozenset[str], # postdrift_identifiers + frozenset[str], # predrift_identifiers +] +BuilderFn = Callable[["ScenarioSpec", int, int], BuilderResult] + + +@dataclass(frozen=True) +class ScenarioSpec: + """Immutable scenario metadata + bound builder.""" + + scenario_id: str + family: Family + tags: frozenset[str] + drift_config: DriftConfig | None + builder: BuilderFn + # Row-count scale passed to the builder. Author-tuned so the + # measured baseline clears ``baseline_min_ms`` on a single build; + # materialize() emits a warning (but does not retry) if the floor + # is not met, signalling the author to bump this value. + base_scale: int = 1_000 + # Per-scenario baseline floor override. Most scenarios inherit the + # module default. Scenarios whose query shape naturally lands at a + # very different baseline (e.g. a trivial single-table GROUP BY + # that can't be meaningfully sped up, or a large join whose raw + # shape is already expensive) can pin a different floor with a + # documented rationale at the SPEC site. + baseline_min_ms: float = BASELINE_MIN_MS + + def materialize(self, seed: int, *, difficulty: DifficultyLevel = "normal") -> ScenarioInstance: + return materialize(self, seed, difficulty=difficulty) + + +def count_tokens(sql: str) -> int: + """Rough whitespace/punctuation token count — good enough for baseline.""" + import re + + return len(re.findall(r"[\w]+|[^\s\w]", sql)) + + +def _scale_for_difficulty(base_scale: int, difficulty: DifficultyLevel) -> int: + """Map a coarse difficulty level onto the scenario builder's row-count scale.""" + if difficulty == "easy": + return max(1, base_scale // 2) + if difficulty == "hard": + return base_scale * 2 + return base_scale + + +def materialize( + spec: ScenarioSpec, seed: int, *, difficulty: DifficultyLevel = "normal" +) -> ScenarioInstance: + """Build a ScenarioInstance once, measure baseline, and return. + + Single build — deterministic, no retry. If the measured baseline is + below ``spec.baseline_min_ms`` a warning is logged so scenario + authors can bump ``base_scale``; the instance is still returned so + episodes can proceed (the rubric gracefully handles small + baselines via the 1.2× speedup gate and infinite-speedup cap). + """ + from engine.profiler import median_of_3_warm_ms + from engine.verifier import canonical_row_hash + + scale = _scale_for_difficulty(spec.base_scale, difficulty) + + ( + conn, + baseline_sql, + gt_pre, + gt_post, + synopsis, + postdrift_ids, + predrift_ids, + ) = spec.builder(spec, seed, scale) + try: + baseline_ms = median_of_3_warm_ms(conn, baseline_sql) + except Exception: + conn.close() + raise + if baseline_ms < spec.baseline_min_ms: + _LOG.warning( + "%s: baseline %.2fms < %.2fms floor at difficulty=%s scale=%d — bump base_scale", + spec.scenario_id, + baseline_ms, + spec.baseline_min_ms, + difficulty, + scale, + ) + + pre_rows = conn.execute(gt_pre).fetchall() + gt_hash_pre = canonical_row_hash(pre_rows) + # Post-drift ground-truth hashes are computed AFTER drift is applied + # at runtime — not here. The env backfills them from gt_post once + # drift fires. + return ScenarioInstance( + conn=conn, + baseline_sql=baseline_sql, + gt_sql_predrift=gt_pre, + gt_sql_postdrift=gt_post, + baseline_runtime_ms=baseline_ms, + baseline_tokens=count_tokens(baseline_sql), + gt_result_hash_predrift=gt_hash_pre, + gt_result_hash_postdrift=None, + drift_config=spec.drift_config, + schema_synopsis=synopsis, + postdrift_identifiers=postdrift_ids, + predrift_identifiers=predrift_ids, + ) + + +__all__ = [ + "BASELINE_MIN_MS", + "BuilderFn", + "BuilderResult", + "DifficultyLevel", + "DriftConfig", + "DriftKind", + "Family", + "ScenarioInstance", + "ScenarioSpec", + "count_tokens", + "materialize", +] diff --git a/scenarios/scenario_01_correlated_subquery.py b/scenarios/scenario_01_correlated_subquery.py new file mode 100644 index 0000000000000000000000000000000000000000..7f3ad2eb2ee82398eb3b6f294764dd7d103fb08e --- /dev/null +++ b/scenarios/scenario_01_correlated_subquery.py @@ -0,0 +1,93 @@ +"""Scenario 01 — correlated subquery anti-pattern. + +Baseline computes a per-user order-count via a correlated subquery in the +projection list. The rewrite replaces it with a LEFT JOIN on a grouped +aggregate so the query runs once instead of once per outer row. + +Schema: users(id, signup_month, tier), orders(id, user_id, amount, status). +""" + +from __future__ import annotations + +import duckdb + +from ._fixtures import ( + categorical_choices, + lognormal_amounts, + seeded_rng, + zipfian_choices, +) +from .base import BuilderResult, ScenarioSpec + + +def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult: + rng = seeded_rng(spec.scenario_id, seed, scale) + n_users = scale + n_orders = scale * 6 + + user_ids = list(range(1, n_users + 1)) + tiers = categorical_choices( + rng, ["free", "pro", "business"], n_users, weights=[0.7, 0.25, 0.05] + ) + signup_months = rng.choices(list(range(1, 13)), k=n_users) + + order_user_ids = zipfian_choices(rng, user_ids, n_orders) + amounts = lognormal_amounts(rng, n_orders, mu=3.2, sigma=0.9) + statuses = categorical_choices( + rng, + ["placed", "fulfilled", "refunded", "cancelled"], + n_orders, + weights=[0.6, 0.3, 0.05, 0.05], + ) + + conn = duckdb.connect(":memory:") + conn.execute("CREATE TABLE users( id BIGINT PRIMARY KEY, signup_month INTEGER, tier VARCHAR);") + conn.execute( + "CREATE TABLE orders(" + " id BIGINT PRIMARY KEY, user_id BIGINT, amount DOUBLE, status VARCHAR);" + ) + conn.executemany( + "INSERT INTO users VALUES (?, ?, ?)", + list(zip(user_ids, signup_months, tiers, strict=False)), + ) + conn.executemany( + "INSERT INTO orders VALUES (?, ?, ?, ?)", + [ + (oid, uid, amt, st) + for oid, (uid, amt, st) in enumerate( + zip(order_user_ids, amounts, statuses, strict=False), start=1 + ) + ], + ) + + baseline_sql = ( + "SELECT u.id, u.tier, " + "(SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id AND o.status = 'fulfilled') " + " AS fulfilled_orders " + "FROM users u " + "ORDER BY u.id" + ) + gt_sql_predrift = ( + "SELECT u.id, u.tier, COALESCE(c.n, 0) AS fulfilled_orders " + "FROM users u LEFT JOIN (" + " SELECT user_id, COUNT(*) AS n FROM orders " + " WHERE status = 'fulfilled' GROUP BY user_id" + ") c ON c.user_id = u.id " + "ORDER BY u.id" + ) + + synopsis = ( + "users(id PK, signup_month, tier); orders(id PK, user_id→users.id, amount, status). " + "Baseline scans orders once per user via a correlated subquery." + ) + return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset() + + +SPEC = ScenarioSpec( + scenario_id="01_correlated_subquery", + family="ecommerce", + tags=frozenset({"correlated_subquery", "projection_subquery", "ecommerce"}), + drift_config=None, + builder=_build, + base_scale=800, +) diff --git a/scenarios/scenario_02_select_star_join.py b/scenarios/scenario_02_select_star_join.py new file mode 100644 index 0000000000000000000000000000000000000000..dbb53a91f4f630bc44239445e0e938be8cddb0c0 --- /dev/null +++ b/scenarios/scenario_02_select_star_join.py @@ -0,0 +1,119 @@ +"""Scenario 02 — SELECT * over a join anti-pattern. + +Baseline pulls every column from three joined tables when only two columns +are needed. The rewrite projects exactly the requested columns. + +Schema: products(id, sku, name, description, price_cents, …), orders(id, user_id, created_at), + order_items(order_id, product_id, qty, unit_price_cents). +""" + +from __future__ import annotations + +import duckdb + +from ._fixtures import ( + categorical_choices, + lognormal_amounts, + seeded_rng, + unique_names, + zipfian_choices, +) +from .base import BuilderResult, ScenarioSpec + + +def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult: + rng = seeded_rng(spec.scenario_id, seed, scale) + n_products = max(100, scale // 4) + n_orders = scale + n_items = scale * 3 + n_users = max(50, scale // 8) + + product_ids = list(range(1, n_products + 1)) + skus = unique_names(rng, n_products, prefix="sku") + names = unique_names(rng, n_products, prefix="p") + descriptions = [f"Long marketing copy for {n}" * 6 for n in names] # wide col + prices = [int(x * 100) for x in lognormal_amounts(rng, n_products, mu=3.0, sigma=1.1)] + categories = categorical_choices( + rng, ["books", "electronics", "apparel", "grocery", "home"], n_products + ) + + user_ids = list(range(1, n_users + 1)) + order_user_ids = zipfian_choices(rng, user_ids, n_orders) + order_created = [1_700_000_000 + rng.randrange(60 * 86_400) for _ in range(n_orders)] + + item_order_ids = rng.choices(list(range(1, n_orders + 1)), k=n_items) + item_product_ids = zipfian_choices(rng, product_ids, n_items) + item_qty = rng.choices([1, 1, 1, 2, 2, 3, 4, 5], k=n_items) + item_unit_price = [prices[pid - 1] for pid in item_product_ids] + + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE products(" + " id BIGINT PRIMARY KEY, sku VARCHAR, name VARCHAR, description VARCHAR," + " price_cents BIGINT, category VARCHAR);" + ) + conn.execute( + "CREATE TABLE orders( id BIGINT PRIMARY KEY, user_id BIGINT, created_at_epoch_s BIGINT);" + ) + conn.execute( + "CREATE TABLE order_items(" + " order_id BIGINT, product_id BIGINT, qty INTEGER, unit_price_cents BIGINT);" + ) + conn.executemany( + "INSERT INTO products VALUES (?, ?, ?, ?, ?, ?)", + list(zip(product_ids, skus, names, descriptions, prices, categories, strict=False)), + ) + conn.executemany( + "INSERT INTO orders VALUES (?, ?, ?)", + [ + (oid, uid, ts) + for oid, (uid, ts) in enumerate(zip(order_user_ids, order_created, strict=False), 1) + ], + ) + conn.executemany( + "INSERT INTO order_items VALUES (?, ?, ?, ?)", + list(zip(item_order_ids, item_product_ids, item_qty, item_unit_price, strict=False)), + ) + + # Anti-pattern: SELECT * inside a subquery that wraps the real join, then + # the outer query projects only a handful of columns. DuckDB materializes + # every column of the subquery before the projection can prune it. + baseline_sql = ( + "SELECT t.order_id, t.name, t.qty " + "FROM (" + " SELECT * FROM order_items oi " + " JOIN products p ON p.id = oi.product_id " + " JOIN orders o ON o.id = oi.order_id " + " WHERE p.category = 'books' AND oi.qty >= 2" + ") t " + "ORDER BY t.order_id, t.name" + ) + gt_sql_predrift = ( + "SELECT oi.order_id, p.name, oi.qty " + "FROM order_items oi " + "JOIN products p ON p.id = oi.product_id " + "JOIN orders o ON o.id = oi.order_id " + "WHERE p.category = 'books' AND oi.qty >= 2 " + "ORDER BY oi.order_id, p.name" + ) + + synopsis = ( + "products(id PK, sku, name, description, price_cents, category); " + "orders(id PK, user_id, created_at_epoch_s); " + "order_items(order_id, product_id, qty, unit_price_cents). " + "Baseline wraps a three-way join with SELECT * inside a subquery; " + "only (order_id, product name, qty) are needed downstream." + ) + return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset() + + +SPEC = ScenarioSpec( + scenario_id="02_select_star_join", + family="ecommerce", + tags=frozenset({"select_star", "over_projection", "join", "ecommerce"}), + drift_config=None, + builder=_build, + # Three-way join with SELECT * is already well-optimized by DuckDB; + # a larger base keeps first-try baseline above the 1 ms floor. + base_scale=1_500, +) diff --git a/scenarios/scenario_03_cartesian_join.py b/scenarios/scenario_03_cartesian_join.py new file mode 100644 index 0000000000000000000000000000000000000000..29696c6278f9727c332509d4c06b31cc10ed800b --- /dev/null +++ b/scenarios/scenario_03_cartesian_join.py @@ -0,0 +1,93 @@ +"""Scenario 03 — accidental cartesian join. + +Baseline joins events × tenants without an equijoin clause and relies on a +post-filter `WHERE` to restore tenant scoping. The optimizer cannot push +the filter into the join, so it materializes the full |events|·|tenants| +cross-product first. The rewrite moves the tenant key into an ON clause. + +Schema: tenants(id, tier), events(id, tenant_id, kind, severity). +""" + +from __future__ import annotations + +import duckdb + +from ._fixtures import categorical_choices, seeded_rng, zipfian_choices +from .base import BuilderResult, ScenarioSpec + + +def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult: + rng = seeded_rng(spec.scenario_id, seed, scale) + n_tenants = max(20, scale // 40) + n_events = scale * 4 + + tenant_ids = list(range(1, n_tenants + 1)) + tiers = categorical_choices( + rng, ["free", "pro", "enterprise"], n_tenants, weights=[0.6, 0.3, 0.1] + ) + event_tenant_ids = zipfian_choices(rng, tenant_ids, n_events) + kinds = categorical_choices( + rng, + ["login", "logout", "action", "error"], + n_events, + weights=[0.35, 0.3, 0.3, 0.05], + ) + severities = categorical_choices( + rng, + ["info", "warn", "error", "critical"], + n_events, + weights=[0.7, 0.2, 0.08, 0.02], + ) + + conn = duckdb.connect(":memory:") + conn.execute("CREATE TABLE tenants(id BIGINT PRIMARY KEY, tier VARCHAR);") + conn.execute( + "CREATE TABLE events(" + " id BIGINT PRIMARY KEY, tenant_id BIGINT, kind VARCHAR, severity VARCHAR);" + ) + conn.executemany( + "INSERT INTO tenants VALUES (?, ?)", list(zip(tenant_ids, tiers, strict=False)) + ) + conn.executemany( + "INSERT INTO events VALUES (?, ?, ?, ?)", + [ + (i, tid, k, s) + for i, (tid, k, s) in enumerate( + zip(event_tenant_ids, kinds, severities, strict=False), start=1 + ) + ], + ) + + baseline_sql = ( + "SELECT t.tier, COUNT(*) AS n " + "FROM events e, tenants t " + "WHERE t.id = e.tenant_id + 0 " # defeat optimizer recognition + "AND e.severity IN ('error', 'critical') " + "GROUP BY t.tier ORDER BY t.tier" + ) + gt_sql_predrift = ( + "SELECT t.tier, COUNT(*) AS n " + "FROM events e JOIN tenants t ON t.id = e.tenant_id " + "WHERE e.severity IN ('error', 'critical') " + "GROUP BY t.tier ORDER BY t.tier" + ) + + synopsis = ( + "tenants(id PK, tier); events(id PK, tenant_id→tenants.id, kind, severity). " + "Baseline relies on a WHERE-clause equijoin obscured by arithmetic, " + "forcing a cartesian materialization." + ) + return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset() + + +SPEC = ScenarioSpec( + scenario_id="03_cartesian_join", + family="events", + tags=frozenset({"cartesian", "missing_join_condition", "events"}), + drift_config=None, + builder=_build, + # Cartesian materialization cost grows with |events|·|tenants| so + # larger base_scale keeps first-try baseline above the 1 ms floor; + # the reroll loop doubles from here if needed. + base_scale=1_200, +) diff --git a/scenarios/scenario_04_distinct_groupby.py b/scenarios/scenario_04_distinct_groupby.py new file mode 100644 index 0000000000000000000000000000000000000000..33a4358c3b0144e59338408005d35ee9869f9adc --- /dev/null +++ b/scenarios/scenario_04_distinct_groupby.py @@ -0,0 +1,73 @@ +"""Scenario 04 — redundant DISTINCT on top of GROUP BY. + +Baseline applies DISTINCT after already grouping, forcing a second +deduplication over a result whose GROUP BY keys are already unique. +Rewrite drops the DISTINCT entirely. + +Schema: pageviews(id, session_id, path, ts_epoch_s). +""" + +from __future__ import annotations + +import duckdb + +from ._fixtures import categorical_choices, seeded_rng, unique_names, zipfian_choices +from .base import BuilderResult, ScenarioSpec + + +def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult: + rng = seeded_rng(spec.scenario_id, seed, scale) + n_sessions = max(200, scale // 4) + n_views = scale * 8 + + session_ids = unique_names(rng, n_sessions, prefix="sess") + paths = categorical_choices( + rng, + [f"/path/{p}" for p in ["home", "about", "product", "cart", "checkout", "help"]], + n_views, + ) + view_sessions = zipfian_choices(rng, list(range(n_sessions)), n_views) + view_session_ids = [session_ids[i] for i in view_sessions] + view_ts = [1_700_000_000 + rng.randrange(30 * 86_400) for _ in range(n_views)] + + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE pageviews(" + " id BIGINT PRIMARY KEY, session_id VARCHAR, path VARCHAR, ts_epoch_s BIGINT);" + ) + conn.executemany( + "INSERT INTO pageviews VALUES (?, ?, ?, ?)", + [ + (i, sid, p, ts) + for i, (sid, p, ts) in enumerate( + zip(view_session_ids, paths, view_ts, strict=False), start=1 + ) + ], + ) + + baseline_sql = ( + "SELECT DISTINCT session_id, path, COUNT(*) AS hits " + "FROM pageviews GROUP BY session_id, path " + "ORDER BY session_id, path" + ) + gt_sql_predrift = ( + "SELECT session_id, path, COUNT(*) AS hits " + "FROM pageviews GROUP BY session_id, path " + "ORDER BY session_id, path" + ) + + synopsis = ( + "pageviews(id PK, session_id, path, ts_epoch_s). " + "Baseline applies redundant DISTINCT on a GROUP BY whose keys are already unique." + ) + return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset() + + +SPEC = ScenarioSpec( + scenario_id="04_distinct_groupby", + family="saas_logs", + tags=frozenset({"distinct", "redundant_distinct", "group_by", "saas_logs"}), + drift_config=None, + builder=_build, + base_scale=500, +) diff --git a/scenarios/scenario_05_nested_subquery.py b/scenarios/scenario_05_nested_subquery.py new file mode 100644 index 0000000000000000000000000000000000000000..72e394f0011c91ea61181f0c05b9c282041bc5ee --- /dev/null +++ b/scenarios/scenario_05_nested_subquery.py @@ -0,0 +1,100 @@ +"""Scenario 05 — deeply nested subquery anti-pattern. + +Baseline stacks two layers of `IN (SELECT ...)` where a single join + GROUP +BY expresses the same intent. The rewrite flattens the nesting. + +Schema: articles(id, author_id, published_at), comments(id, article_id, author_id), + authors(id, display_name, is_verified). +""" + +from __future__ import annotations + +import duckdb + +from ._fixtures import categorical_choices, seeded_rng, unique_names, zipfian_choices +from .base import BuilderResult, ScenarioSpec + + +def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult: + rng = seeded_rng(spec.scenario_id, seed, scale) + n_authors = max(50, scale // 8) + n_articles = scale + n_comments = scale * 4 + + author_ids = list(range(1, n_authors + 1)) + display_names = unique_names(rng, n_authors, prefix="author") + verified = rng.choices([True, False], weights=[0.2, 0.8], k=n_authors) + + article_ids = list(range(1, n_articles + 1)) + article_authors = zipfian_choices(rng, author_ids, n_articles) + article_published = [1_700_000_000 + rng.randrange(120 * 86_400) for _ in range(n_articles)] + article_statuses = categorical_choices( + rng, ["draft", "published", "archived"], n_articles, weights=[0.2, 0.7, 0.1] + ) + + comment_ids = list(range(1, n_comments + 1)) + comment_articles = rng.choices(article_ids, k=n_comments) + comment_authors = zipfian_choices(rng, author_ids, n_comments) + + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE authors( id BIGINT PRIMARY KEY, display_name VARCHAR, is_verified BOOLEAN);" + ) + conn.execute( + "CREATE TABLE articles(" + " id BIGINT PRIMARY KEY, author_id BIGINT, published_at_epoch_s BIGINT," + " status VARCHAR);" + ) + conn.execute( + "CREATE TABLE comments( id BIGINT PRIMARY KEY, article_id BIGINT, author_id BIGINT);" + ) + conn.executemany( + "INSERT INTO authors VALUES (?, ?, ?)", + list(zip(author_ids, display_names, verified, strict=False)), + ) + conn.executemany( + "INSERT INTO articles VALUES (?, ?, ?, ?)", + list(zip(article_ids, article_authors, article_published, article_statuses, strict=False)), + ) + conn.executemany( + "INSERT INTO comments VALUES (?, ?, ?)", + list(zip(comment_ids, comment_articles, comment_authors, strict=False)), + ) + + baseline_sql = ( + "SELECT display_name " + "FROM authors " + "WHERE id IN (" + " SELECT author_id FROM comments " + " WHERE article_id IN (" + " SELECT id FROM articles WHERE status = 'published'" + " )" + ") " + "ORDER BY display_name" + ) + gt_sql_predrift = ( + "SELECT DISTINCT a.display_name " + "FROM authors a " + "JOIN comments c ON c.author_id = a.id " + "JOIN articles ar ON ar.id = c.article_id " + "WHERE ar.status = 'published' " + "ORDER BY a.display_name" + ) + + synopsis = ( + "authors(id PK, display_name, is_verified); " + "articles(id PK, author_id→authors.id, published_at_epoch_s, status); " + "comments(id PK, article_id→articles.id, author_id→authors.id). " + "Baseline chains two IN-subqueries where one JOIN + DISTINCT suffices." + ) + return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset() + + +SPEC = ScenarioSpec( + scenario_id="05_nested_subquery", + family="cms", + tags=frozenset({"nested_subquery", "in_subquery", "cms"}), + drift_config=None, + builder=_build, + base_scale=1_500, +) diff --git a/scenarios/scenario_06_having_as_where.py b/scenarios/scenario_06_having_as_where.py new file mode 100644 index 0000000000000000000000000000000000000000..c6251c944ea0b570665d342fd0787cc5e4d58dda --- /dev/null +++ b/scenarios/scenario_06_having_as_where.py @@ -0,0 +1,92 @@ +"""Scenario 06 — HAVING used as WHERE. + +Baseline filters on a grouping key inside HAVING, which forces the engine +to group first and filter after. The rewrite moves the non-aggregate +predicate into WHERE. + +Schema: orders(id, tenant_id, user_id, amount_cents, status, created_at_epoch_s). +""" + +from __future__ import annotations + +import duckdb + +from ._fixtures import categorical_choices, lognormal_amounts, seeded_rng, zipfian_choices +from .base import BuilderResult, ScenarioSpec + + +def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult: + rng = seeded_rng(spec.scenario_id, seed, scale) + n_tenants = max(40, scale // 40) + n_users = max(200, scale // 4) + n_orders = scale * 10 + + tenant_ids = list(range(1, n_tenants + 1)) + user_ids = list(range(1, n_users + 1)) + user_tenants = rng.choices(tenant_ids, k=n_users) + + order_user = zipfian_choices(rng, user_ids, n_orders) + order_tenant = [user_tenants[u - 1] for u in order_user] + order_amount = [int(x * 100) for x in lognormal_amounts(rng, n_orders, mu=3.5, sigma=0.8)] + statuses = categorical_choices( + rng, + ["placed", "fulfilled", "refunded", "cancelled"], + n_orders, + weights=[0.55, 0.3, 0.1, 0.05], + ) + created = [1_700_000_000 + rng.randrange(60 * 86_400) for _ in range(n_orders)] + + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE orders(" + " id BIGINT PRIMARY KEY, tenant_id BIGINT, user_id BIGINT," + " amount_cents BIGINT, status VARCHAR, created_at_epoch_s BIGINT);" + ) + conn.executemany( + "INSERT INTO orders VALUES (?, ?, ?, ?, ?, ?)", + [ + (i, t, u, a, s, ts) + for i, (t, u, a, s, ts) in enumerate( + zip(order_tenant, order_user, order_amount, statuses, created, strict=False), + start=1, + ) + ], + ) + + # Anti-pattern: filter on a row-level column (status) that IS in GROUP BY + # inside HAVING, forcing the engine to build groups for every status value + # before discarding most of them. The rewrite moves the filter into WHERE + # so aggregation only runs over rows we actually care about. + baseline_sql = ( + "SELECT tenant_id, status, SUM(amount_cents) AS total_cents " + "FROM orders " + "GROUP BY tenant_id, status " + "HAVING status = 'fulfilled' " + " AND SUM(amount_cents) >= 100000 " + "ORDER BY tenant_id" + ) + gt_sql_predrift = ( + "SELECT tenant_id, status, SUM(amount_cents) AS total_cents " + "FROM orders " + "WHERE status = 'fulfilled' " + "GROUP BY tenant_id, status " + "HAVING SUM(amount_cents) >= 100000 " + "ORDER BY tenant_id" + ) + + synopsis = ( + "orders(id PK, tenant_id, user_id, amount_cents, status, created_at_epoch_s). " + "Baseline filters `status` inside HAVING, forcing aggregation over " + "every status group before discarding all but 'fulfilled'." + ) + return conn, baseline_sql, gt_sql_predrift, None, synopsis, frozenset(), frozenset() + + +SPEC = ScenarioSpec( + scenario_id="06_having_as_where", + family="ecommerce", + tags=frozenset({"having_as_where", "aggregate_filter", "ecommerce"}), + drift_config=None, + builder=_build, + base_scale=1_500, +) diff --git a/scenarios/scenario_07_drift_column_rename.py b/scenarios/scenario_07_drift_column_rename.py new file mode 100644 index 0000000000000000000000000000000000000000..ab64b7648807407d387d62607ae9496bab834846 --- /dev/null +++ b/scenarios/scenario_07_drift_column_rename.py @@ -0,0 +1,73 @@ +"""Scenario 07 — drift: column rename (orders.user_id → orders.account_id). + +Baseline groups orders by `user_id`. When the drift fires, the column is +renamed; the agent must rewrite its query against `account_id`. Row +semantics are unchanged — the post-drift hash equals the pre-drift hash +because the only thing that changed is the column label. +""" + +from __future__ import annotations + +import duckdb + +from ._fixtures import lognormal_amounts, seeded_rng, zipfian_choices +from .base import BuilderResult, DriftConfig, ScenarioSpec + + +def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult: + rng = seeded_rng(spec.scenario_id, seed, scale) + n_users = max(200, scale // 2) + n_orders = scale * 4 + + user_ids = list(range(1, n_users + 1)) + order_users = zipfian_choices(rng, user_ids, n_orders) + amounts = lognormal_amounts(rng, n_orders, mu=3.0, sigma=0.8) + + conn = duckdb.connect(":memory:") + conn.execute("CREATE TABLE users(id BIGINT PRIMARY KEY, signup_month INTEGER);") + conn.execute("CREATE TABLE orders(id BIGINT PRIMARY KEY, user_id BIGINT, amount DOUBLE);") + conn.executemany( + "INSERT INTO users VALUES (?, ?)", + [(uid, (uid % 12) + 1) for uid in user_ids], + ) + conn.executemany( + "INSERT INTO orders VALUES (?, ?, ?)", + [(i, uid, amt) for i, (uid, amt) in enumerate(zip(order_users, amounts, strict=False), 1)], + ) + + baseline_sql = ( + "SELECT user_id, COUNT(*) AS n_orders, ROUND(SUM(amount), 2) AS total " + "FROM orders GROUP BY user_id ORDER BY user_id" + ) + gt_sql_predrift = baseline_sql # static part — baseline IS correct pre-drift + gt_sql_postdrift = ( + "SELECT account_id, COUNT(*) AS n_orders, ROUND(SUM(amount), 2) AS total " + "FROM orders GROUP BY account_id ORDER BY account_id" + ) + + synopsis = ( + "users(id PK, signup_month); orders(id PK, user_id→users.id, amount). " + "Under drift, orders.user_id is renamed to orders.account_id." + ) + return ( + conn, + baseline_sql, + gt_sql_predrift, + gt_sql_postdrift, + synopsis, + frozenset({"account_id"}), + frozenset({"user_id"}), + ) + + +SPEC = ScenarioSpec( + scenario_id="07_drift_column_rename", + family="ecommerce", + tags=frozenset({"drift", "column_rename", "ecommerce"}), + drift_config=DriftConfig( + kind="column_rename", + payload={"table": "orders", "old": "user_id", "new": "account_id"}, + ), + builder=_build, + base_scale=2_000, +) diff --git a/scenarios/scenario_08_drift_date_format.py b/scenarios/scenario_08_drift_date_format.py new file mode 100644 index 0000000000000000000000000000000000000000..07a27075f23e271edaed4e73a16c32ede9804beb --- /dev/null +++ b/scenarios/scenario_08_drift_date_format.py @@ -0,0 +1,93 @@ +"""Scenario 08 — drift: date format (events.ts iso_string → epoch_ms). + +Baseline filters events inside a specific UTC day via ISO string +comparisons. When the drift fires, the `ts` column becomes a BIGINT of +epoch-ms; the agent must rewrite comparisons against the numeric value. +""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta + +import duckdb + +from ._fixtures import categorical_choices, seeded_rng +from .base import BuilderResult, DriftConfig, ScenarioSpec + +# Fixed anchor day — comparisons are deterministic across seeds. +_ANCHOR = datetime(2026, 4, 21, tzinfo=UTC) +_ANCHOR_NEXT = _ANCHOR + timedelta(days=1) +_ANCHOR_ISO = _ANCHOR.isoformat().replace("+00:00", "Z") +_ANCHOR_NEXT_ISO = _ANCHOR_NEXT.isoformat().replace("+00:00", "Z") +_ANCHOR_MS = int(_ANCHOR.timestamp() * 1000) +_ANCHOR_NEXT_MS = int(_ANCHOR_NEXT.timestamp() * 1000) + + +def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult: + rng = seeded_rng(spec.scenario_id, seed, scale) + n_events = scale * 8 + + window_start = _ANCHOR - timedelta(days=3) + window_span_s = 7 * 86_400 + event_dts = [ + (window_start + timedelta(seconds=rng.randrange(window_span_s))) for _ in range(n_events) + ] + event_iso = [dt.isoformat().replace("+00:00", "Z") for dt in event_dts] + kinds = categorical_choices( + rng, ["login", "action", "error"], n_events, weights=[0.6, 0.35, 0.05] + ) + + conn = duckdb.connect(":memory:") + conn.execute("CREATE TABLE events( id BIGINT PRIMARY KEY, ts VARCHAR, kind VARCHAR);") + conn.executemany( + "INSERT INTO events VALUES (?, ?, ?)", + [(i, iso, k) for i, (iso, k) in enumerate(zip(event_iso, kinds, strict=False), 1)], + ) + + baseline_sql = ( + f"SELECT kind, COUNT(*) AS n FROM events " + f"WHERE ts >= '{_ANCHOR_ISO}' AND ts < '{_ANCHOR_NEXT_ISO}' " + f"GROUP BY kind ORDER BY kind" + ) + gt_sql_predrift = baseline_sql + gt_sql_postdrift = ( + f"SELECT kind, COUNT(*) AS n FROM events " + f"WHERE ts >= {_ANCHOR_MS} AND ts < {_ANCHOR_NEXT_MS} " + f"GROUP BY kind ORDER BY kind" + ) + + synopsis = ( + "events(id PK, ts VARCHAR(ISO-8601 UTC), kind). Under drift, `ts` " + f"becomes BIGINT epoch-ms. Filter window is {_ANCHOR_ISO} – " + f"{_ANCHOR_NEXT_ISO} (i.e. epoch-ms in " + f"[{_ANCHOR_MS}, {_ANCHOR_NEXT_MS}))." + ) + # Date-format drift keeps the ``ts`` identifier; what changes is the + # literal shape (ISO string → epoch-ms integer). The rubric + # therefore can't distinguish "adapted" from "not adapted" on + # identifiers alone, so we expose the ISO anchor strings as the + # pre-drift distinctive set and leave postdrift empty — the rubric + # treats absence-of-predrift-markers as adaptation whenever + # ``postdrift_identifiers`` is empty (DriftAdapt rubric case). + return ( + conn, + baseline_sql, + gt_sql_predrift, + gt_sql_postdrift, + synopsis, + frozenset(), + frozenset({_ANCHOR_ISO, _ANCHOR_NEXT_ISO}), + ) + + +SPEC = ScenarioSpec( + scenario_id="08_drift_date_format", + family="events", + tags=frozenset({"drift", "date_format", "iso_to_epoch", "events"}), + drift_config=DriftConfig( + kind="date_format", + payload={"table": "events", "col": "ts", "from": "iso_string", "to": "epoch_ms"}, + ), + builder=_build, + base_scale=500, +) diff --git a/scenarios/scenario_09_drift_enum_rule.py b/scenarios/scenario_09_drift_enum_rule.py new file mode 100644 index 0000000000000000000000000000000000000000..81126a5262fa0a0dfa6a11db839aea48983f0b37 --- /dev/null +++ b/scenarios/scenario_09_drift_enum_rule.py @@ -0,0 +1,85 @@ +"""Scenario 09 — drift: enum split ('active' → 'ACTIVE' / 'ACTIVE_V2'). + +Baseline counts tenants whose status is 'active'. After drift, every +previously-'active' row has been relabelled to either 'ACTIVE' or +'ACTIVE_V2' (deterministic round-robin). The agent must filter on the +union of the new values to recover the business-equivalent count. + +Note: unlike 07 and 10, the post-drift data changed, so the post-drift +ground-truth hash is computed against the post-drift rows — the agent's +result set now reflects the new status values. +""" + +from __future__ import annotations + +import duckdb + +from ._fixtures import categorical_choices, seeded_rng +from .base import BuilderResult, DriftConfig, ScenarioSpec + + +def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult: + rng = seeded_rng(spec.scenario_id, seed, scale) + n_tenants = max(400, scale) + + statuses = categorical_choices( + rng, + ["active", "trial", "suspended", "churned"], + n_tenants, + weights=[0.55, 0.2, 0.15, 0.1], + ) + tiers = categorical_choices( + rng, ["free", "pro", "business"], n_tenants, weights=[0.6, 0.3, 0.1] + ) + + conn = duckdb.connect(":memory:") + conn.execute("CREATE TABLE tenants( id BIGINT PRIMARY KEY, status VARCHAR, tier VARCHAR);") + conn.executemany( + "INSERT INTO tenants VALUES (?, ?, ?)", + [(i, s, t) for i, (s, t) in enumerate(zip(statuses, tiers, strict=False), 1)], + ) + + # Pre-drift: all 'active' tenants; post-drift: union of the new labels. + baseline_sql = ( + "SELECT tier, COUNT(*) AS n " + "FROM tenants WHERE status = 'active' " + "GROUP BY tier ORDER BY tier" + ) + gt_sql_predrift = baseline_sql + gt_sql_postdrift = ( + "SELECT tier, COUNT(*) AS n " + "FROM tenants WHERE status IN ('ACTIVE', 'ACTIVE_V2') " + "GROUP BY tier ORDER BY tier" + ) + + synopsis = ( + "tenants(id PK, status, tier). Under drift, status='active' is split " + "into 'ACTIVE' and 'ACTIVE_V2'; 'trial'/'suspended'/'churned' are unchanged." + ) + return ( + conn, + baseline_sql, + gt_sql_predrift, + gt_sql_postdrift, + synopsis, + frozenset({"ACTIVE", "ACTIVE_V2"}), + frozenset({"active"}), + ) + + +SPEC = ScenarioSpec( + scenario_id="09_drift_enum_rule", + family="multitenant", + tags=frozenset({"drift", "enum_rule", "business_rule", "multitenant"}), + drift_config=DriftConfig( + kind="enum_rule", + payload={ + "table": "tenants", + "col": "status", + "old_value": "active", + "new_values": ["ACTIVE", "ACTIVE_V2"], + }, + ), + builder=_build, + base_scale=600, +) diff --git a/scenarios/scenario_10_drift_field_deprecation.py b/scenarios/scenario_10_drift_field_deprecation.py new file mode 100644 index 0000000000000000000000000000000000000000..e6cb375f2379e52b1df4099789ff60b7babf5957 --- /dev/null +++ b/scenarios/scenario_10_drift_field_deprecation.py @@ -0,0 +1,74 @@ +"""Scenario 10 — drift: inline string col replaced by FK lookup. + +Baseline groups posts by the inline `author_name` string. Under drift, a +`users(id, full_name)` lookup is created, `posts.author_name` is dropped, +and `posts.users_id` is added with a backfilled FK. The correct rewrite +joins through `users` and groups by `full_name`. +""" + +from __future__ import annotations + +import duckdb + +from ._fixtures import seeded_rng, unique_names, zipfian_choices +from .base import BuilderResult, DriftConfig, ScenarioSpec + + +def _build(spec: ScenarioSpec, seed: int, scale: int) -> BuilderResult: + rng = seeded_rng(spec.scenario_id, seed, scale) + n_authors = max(40, scale // 10) + n_posts = scale * 4 + + author_names = unique_names(rng, n_authors, prefix="author") + post_author_idx = zipfian_choices(rng, list(range(n_authors)), n_posts) + post_author_names = [author_names[i] for i in post_author_idx] + + conn = duckdb.connect(":memory:") + conn.execute("CREATE TABLE posts( id BIGINT PRIMARY KEY, author_name VARCHAR, body VARCHAR);") + conn.executemany( + "INSERT INTO posts VALUES (?, ?, ?)", + [(i, n, f"body {i}") for i, n in enumerate(post_author_names, 1)], + ) + + baseline_sql = ( + "SELECT author_name, COUNT(*) AS n_posts " + "FROM posts GROUP BY author_name " + "ORDER BY author_name" + ) + gt_sql_predrift = baseline_sql + gt_sql_postdrift = ( + "SELECT u.full_name AS author_name, COUNT(*) AS n_posts " + "FROM posts p JOIN users u ON u.id = p.users_id " + "GROUP BY u.full_name ORDER BY u.full_name" + ) + + synopsis = ( + "posts(id PK, author_name, body). Under drift, posts.author_name is " + "deprecated; a new users(id PK, full_name) table is created and " + "posts gains a users_id FK. Rewrites must JOIN through users." + ) + return ( + conn, + baseline_sql, + gt_sql_predrift, + gt_sql_postdrift, + synopsis, + frozenset({"users", "users_id", "full_name"}), + frozenset({"author_name"}), + ) + + +SPEC = ScenarioSpec( + scenario_id="10_drift_field_deprecation", + family="cms", + tags=frozenset({"drift", "field_deprecation", "fk_backfill", "cms"}), + drift_config=DriftConfig( + kind="field_deprecation", + payload={ + "orig": ("posts", "author_name"), + "lookup": ("users", "id", "full_name"), + }, + ), + builder=_build, + base_scale=1_500, +) diff --git a/server/__init__.py b/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..61f6222732bd858ad96b6491d127e5fc4e3c2b88 --- /dev/null +++ b/server/__init__.py @@ -0,0 +1,5 @@ +"""SQLDrift server package.""" + +from server.sql_drift_env_environment import SqlDriftEnvironment + +__all__ = ["SqlDriftEnvironment"] diff --git a/server/app.py b/server/app.py new file mode 100644 index 0000000000000000000000000000000000000000..053e5290542a3375ec6460ff8b9f430dbc632d83 --- /dev/null +++ b/server/app.py @@ -0,0 +1,89 @@ +"""FastAPI factory for the SQLDrift environment. + +``create_app()`` returns a fully-wired FastAPI app exposing the +stateless HTTP routes (``/health``, ``/schema``, ``/reset``, ``/step``) +and the stateful ``/ws`` WebSocket session. Stateful multi-step +episodes must go through ``/ws``; each HTTP ``/step`` spawns a +fresh env instance that is ``close()``-d in ``finally`` (one env per request). + +``main()`` runs the server with Uvicorn — exported as the +``[project.scripts] sql-drift-server`` entry point. +""" + +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from openenv.core.env_server.http_server import create_app as _openenv_create_app + +from models import SqlDriftAction, SqlDriftObservation +from skill_library import DEFAULT_STORE_DIR, Store, cleanup_stale_session_dirs + +from . import settings +from .sql_drift_env_environment import SqlDriftEnvironment + +ENV_NAME = "sql_drift_env" +DEFAULT_MAX_CONCURRENT_ENVS = settings.MAX_CONCURRENT_ENVS +_SESSION_STORE_ROOT = DEFAULT_STORE_DIR / "sessions" + +# Purge stale session directories left by previous server runs before +# accepting any traffic. Failures are non-fatal. +_startup_removed = cleanup_stale_session_dirs( + _SESSION_STORE_ROOT, settings.SKILL_STORE_SESSION_TTL_HOURS +) +if _startup_removed: + import logging as _logging + + _logging.getLogger("sql_drift_env.app.server.app").info( + "startup: removed %d stale session skill-store dirs from %s", + _startup_removed, + _SESSION_STORE_ROOT, + ) + + +def _create_server_environment() -> SqlDriftEnvironment: + """Build one server-managed env with its own on-disk skill library. + + ``cleanup_on_close=True`` ensures the session directory is deleted when + the WebSocket session ends, preventing unbounded on-disk session growth. + """ + session_dir = _SESSION_STORE_ROOT / uuid4().hex + return SqlDriftEnvironment( + skill_store=Store(directory=session_dir), + cleanup_on_close=True, + ) + + +def create_app(max_concurrent_envs: int | None = None) -> Any: + """Build the FastAPI app bound to a fresh-env factory per session.""" + if max_concurrent_envs is None: + max_concurrent_envs = DEFAULT_MAX_CONCURRENT_ENVS + return _openenv_create_app( + env=_create_server_environment, + action_cls=SqlDriftAction, + observation_cls=SqlDriftObservation, + env_name=ENV_NAME, + max_concurrent_envs=max_concurrent_envs, + ) + + +def main(host: str = settings.SERVER_HOST, port: int = settings.SERVER_PORT) -> None: + """Uvicorn entry point — matches the [project.scripts] wiring.""" + import uvicorn + + uvicorn.run(create_app(), host=host, port=port) + + +# Module-level app instance for uvicorn's ``module:attr`` syntax +# (``uvicorn server.app:app``) and the ``openenv.yaml`` ``app:`` field. +# Built at import time; safe because the OpenEnv factory only stores the +# environment factory and instantiates per request / session. +app = create_app() + + +__all__ = ["ENV_NAME", "app", "create_app", "main"] + + +if __name__ == "__main__": + main() diff --git a/server/requirements.txt b/server/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..562054df11dc39901cb8fbd92621e73f37b80040 --- /dev/null +++ b/server/requirements.txt @@ -0,0 +1,116 @@ +# This file was autogenerated by uv via the following command: +# uv export --format requirements-txt --no-dev --no-hashes --no-emit-project --no-annotate -o server/requirements.txt +aiofile==3.9.0 +annotated-doc==0.0.4 +annotated-types==0.7.0 +anyio==4.13.0 +attrs==26.1.0 +audioop-lts==0.2.2 ; python_full_version >= '3.13' +authlib==1.7.0 +beartype==0.22.9 +brotli==1.2.0 +cachetools==7.0.6 +caio==0.9.25 +certifi==2026.2.25 +cffi==2.0.0 ; platform_python_implementation != 'PyPy' +charset-normalizer==3.4.7 +click==8.3.2 +colorama==0.4.6 ; sys_platform == 'win32' +cryptography==46.0.7 +cyclopts==4.10.2 +distro==1.9.0 +dnspython==2.8.0 +docstring-parser==0.18.0 +docutils==0.22.4 +duckdb==1.5.2 +email-validator==2.3.0 +exceptiongroup==1.3.1 +fastapi==0.136.0 +fastmcp==3.2.4 +filelock==3.29.0 +fsspec==2026.2.0 +gradio==6.13.0 +gradio-client==2.5.0 +griffelib==2.0.2 +groovy==0.1.2 +h11==0.16.0 +hf-gradio==0.4.1 +hf-xet==1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64' +httpcore==1.0.9 +httpx==0.28.1 +httpx-sse==0.4.3 +huggingface-hub==1.11.0 +idna==3.11 +importlib-metadata==8.7.1 +jaraco-classes==3.4.0 +jaraco-context==6.1.2 +jaraco-functools==4.4.0 +jeepney==0.9.0 ; sys_platform == 'linux' +jinja2==3.1.6 +jiter==0.14.0 +joserfc==1.6.4 +jsonref==1.1.0 +jsonschema==4.26.0 +jsonschema-path==0.4.5 +jsonschema-specifications==2025.9.1 +keyring==25.7.0 +markdown-it-py==4.0.0 +markupsafe==3.0.3 +mcp==1.27.0 +mdurl==0.1.2 +more-itertools==11.0.2 +numpy==2.4.4 +openai==2.32.0 +openapi-pydantic==0.5.1 +openenv-core==0.2.3 +opentelemetry-api==1.41.0 +orjson==3.11.8 +packaging==26.1 +pandas==3.0.2 +pathable==0.5.0 +pillow==12.2.0 +platformdirs==4.9.6 +py-key-value-aio==0.4.4 +pycparser==3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy' +pydantic==2.13.3 +pydantic-core==2.46.3 +pydantic-settings==2.14.0 +pydub==0.25.1 +pygments==2.20.0 +pyjwt==2.12.1 +pyperclip==1.11.0 +python-dateutil==2.9.0.post0 +python-dotenv==1.2.2 +python-multipart==0.0.26 +pytz==2026.1.post1 +pywin32==311 ; sys_platform == 'win32' +pywin32-ctypes==0.2.3 ; sys_platform == 'win32' +pyyaml==6.0.3 +referencing==0.37.0 +requests==2.33.1 +rich==15.0.0 +rich-rst==1.3.2 +rpds-py==0.30.0 +safehttpx==0.1.7 +secretstorage==3.5.0 ; sys_platform == 'linux' +semantic-version==2.10.0 +shellingham==1.5.4 +six==1.17.0 +sniffio==1.3.1 +sqlglot==30.6.0 +sse-starlette==3.3.4 +starlette==1.0.0 +tomli==2.4.1 +tomli-w==1.2.0 +tomlkit==0.14.0 +tqdm==4.67.3 +typer==0.24.1 +typing-extensions==4.15.0 +typing-inspection==0.4.2 +tzdata==2026.1 ; sys_platform == 'emscripten' or sys_platform == 'win32' +uncalled-for==0.3.1 +urllib3==2.6.3 +uvicorn==0.44.0 +watchfiles==1.1.1 +websockets==16.0 +zipp==3.23.1 diff --git a/server/settings.py b/server/settings.py new file mode 100644 index 0000000000000000000000000000000000000000..7c63a9c217a38dbc4e80e08ff901d072bd3bddec --- /dev/null +++ b/server/settings.py @@ -0,0 +1,31 @@ +"""Server/runtime settings loaded from the repo-local `.env`.""" + +from __future__ import annotations + +from utilities.env_loader import env_float, env_int, env_str + +SERVER_HOST = env_str("SQL_DRIFT_SERVER_HOST", "0.0.0.0") +SERVER_PORT = env_int("SQL_DRIFT_SERVER_PORT", 8000, min_value=1, max_value=65535) +MAX_CONCURRENT_ENVS = env_int("SQL_DRIFT_MAX_CONCURRENT_ENVS", 4, min_value=1) + +DEFAULT_STEP_BUDGET = env_int("SQL_DRIFT_DEFAULT_STEP_BUDGET", 25, min_value=1) +MAX_RESULT_ROWS = env_int("SQL_DRIFT_MAX_RESULT_ROWS", 1_000, min_value=1) +QUERY_TIMEOUT_S = env_float("SQL_DRIFT_QUERY_TIMEOUT_S", 2.0, min_value=0.001) + +# Session skill-store directories older than this many hours are removed at +# startup and when the owning environment is closed. Set to 0 to disable +# TTL-based cleanup (directories will still be removed on close when +# cleanup_on_close=True is set for a server-managed environment). +SKILL_STORE_SESSION_TTL_HOURS = env_float( + "SQL_DRIFT_SKILL_STORE_SESSION_TTL_HOURS", 24.0, min_value=0.0 +) + +__all__ = [ + "DEFAULT_STEP_BUDGET", + "MAX_CONCURRENT_ENVS", + "MAX_RESULT_ROWS", + "QUERY_TIMEOUT_S", + "SERVER_HOST", + "SERVER_PORT", + "SKILL_STORE_SESSION_TTL_HOURS", +] diff --git a/server/sql_drift_env_environment.py b/server/sql_drift_env_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..ca6bed5391acc880038581ffcaf26814e8067c7d --- /dev/null +++ b/server/sql_drift_env_environment.py @@ -0,0 +1,899 @@ +"""OpenEnv ``Environment`` implementation for SQLDrift. + +Responsibilities: + +* Own the private :class:`engine.runtime.RuntimeEpisodeState` and the + composite :class:`engine.reward.SqlDriftRubric` for the current episode. +* Dispatch each of the eight tool-call payloads to a dedicated + ``_handle_`` method that returns a typed + :class:`models.ToolResult` (or :class:`models.ToolError`). +* Fire drift on a schedule blended with a cooldown: ``max(scheduled, + first_run_query_step + cooldown)`` before the agent acts on the step + where drift applies, then recompute the post-drift ground truth hash. +* Publish public observations (:class:`models.SqlDriftObservation`) and a + strictly sanitised public state snapshot (:class:`models.SqlDriftState`). + +Privacy: ``self._runtime`` holds the DuckDB handle, ground-truth hashes, +baseline runtime, and seed. They stay inside this class; the rubric reads +them via a closure, and ``env.state`` exposes only a fixed whitelist of fields. +""" + +from __future__ import annotations + +import contextlib +import math +import re +import secrets +from random import Random +from typing import TYPE_CHECKING, Any, Literal + +import duckdb +import sqlglot +from openenv.core.env_server.interfaces import Environment +from pydantic import BaseModel, ConfigDict, Field + +from actors import dba_oracle +from actors.engineering_manager import author_changelog +from engine.drift import apply_drift +from engine.profiler import ( + QueryWatchdogEscalationError, + execute_hash_timed, + execute_once_timed, + execute_once_with_columns, +) +from engine.reward import ( + SPEEDUP_CAP_FOR_INFTY, + STEP_REBATE_DESCRIBE_TABLE, + STEP_REBATE_EXPLAIN_QUERY, + STEP_REBATE_LIST_TABLES, + STEP_REBATE_READ_CHANGELOG, + STEP_REBATE_RUN_QUERY, + STEP_REBATE_SAMPLE_ROWS, + SqlDriftRubric, + canonicalize_sql, + effective_speedup, +) +from engine.runtime import RuntimeEpisodeState +from engine.verifier import canonical_row_hash +from models import ( + REWARD_COMPONENT_KEYS, + ConsultDBAPayload, + ConsultDBAResult, + DescribeTablePayload, + DescribeTableResult, + EpisodePhase, + ExplainQueryPayload, + ExplainQueryResult, + ListTablesPayload, + ListTablesResult, + ReadChangelogPayload, + ReadChangelogResult, + RunQueryPayload, + RunQueryResult, + SampleRowsPayload, + SampleRowsResult, + SqlDriftAction, + SqlDriftObservation, + SqlDriftState, + SubmitRewritePayload, + SubmitRewriteResult, + ToolError, + ToolErrorCode, + ToolResult, +) +from scenarios import REGISTRY, get_spec +from skill_library import PlaybookEntry, Store, load_all, retrieve +from utilities.logger import get_module_logger, log_env_reset, log_env_step, log_interaction + +from . import settings + +if TYPE_CHECKING: + from scenarios.base import ScenarioSpec + +_LOG = get_module_logger(__name__) + +DEFAULT_STEP_BUDGET: int = settings.DEFAULT_STEP_BUDGET +MAX_RESULT_ROWS: int = settings.MAX_RESULT_ROWS +QUERY_TIMEOUT_S: float = settings.QUERY_TIMEOUT_S + + +class _ResetOptions(BaseModel): + model_config = ConfigDict(extra="ignore") + + scenario_id: str | None = None + enable_dba_oracle: bool | None = None + difficulty: Literal["easy", "normal", "hard"] = "normal" + budget_steps: int = Field(default=DEFAULT_STEP_BUDGET, ge=1) + + +_READ_ONLY_EXPRESSION_KEYS: frozenset[str] = frozenset({"select", "with"}) + +# DuckDB exposes a family of table-valued functions and scalar helpers +# that read from the host filesystem or leak introspection state — +# ``read_csv``, ``read_parquet``, ``read_json``, ``read_text``, +# ``parquet_metadata``, ``duckdb_secrets``, ``glob``, etc. They are +# *technically* SELECT-shaped calls so the statement-key check alone +# admits them. We reject any function whose lowercased name starts with +# one of these prefixes or exactly matches one of the known-dangerous +# standalone names. Agent-facing SQL has no legitimate need for any of +# them — the DuckDB connection is pre-populated by the scenario builder. +_DENYLIST_PREFIXES: tuple[str, ...] = ( + "read_", + "write_", + "copy_", + "duckdb_", + "pragma_", + "sniff_", + "parquet_", + "arrow_", + "json_table", + "json_each", + "sqlite_", + "load_", + "install_", +) +_DENYLIST_EXACT: frozenset[str] = frozenset( + { + "glob", + "attach", + "detach", + "checkpoint", + "force_checkpoint", + "set_secret", + "create_secret", + "drop_secret", + "enable_profiling", + "disable_profiling", + "enable_object_cache", + } +) + + +def _is_denylisted_function_name(name: str) -> bool: + """Return True iff ``name`` (case-insensitively) matches a sandbox-escape.""" + lowered = name.lower() + if lowered in _DENYLIST_EXACT: + return True + return any(lowered.startswith(p) for p in _DENYLIST_PREFIXES) + + +def _function_names(node: sqlglot.exp.Func) -> list[str]: + """All plausible names to check against the denylist for one AST node. + + sqlglot lowers a few DuckDB calls into dedicated expression classes + (``ReadCSV``, ``ReadParquet``, …) whose ``.name`` is actually the + first positional arg — the file path — not the function name. We + recover the function name from the class name in that case and fall + back to ``.name`` for the ``Anonymous`` form that covers everything + else. Including both lets one denylist lookup cover both lowerings. + """ + cls = type(node).__name__ + out: list[str] = [] + # Derive a snake-case function name from the class name. We insert + # an underscore at two kinds of CamelCase boundaries: + # + # * ``aB`` — normal lower-to-upper (``ReadParquet`` → ``read_parquet``) + # * ``ABc`` — end of an acronym run (``ReadCSVAuto`` → ``read_csv_auto``) + # + # Purely-lowercase class names (``Anonymous``) produce no prefix + # match; we fall through to ``.name`` below for those. + if cls and cls[0].isupper(): + snake = re.sub(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])", "_", cls).lower() + out.append(snake) + name_attr = getattr(node, "name", None) + if isinstance(name_attr, str) and name_attr: + out.append(name_attr) + return out + + +_VALID_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +def _resolve_timeout_s(timeout_s: float | None) -> float: + """Caller-supplied per-step timeout or the module default. + + ``timeout_s`` is accepted on every OpenEnv ``step()`` (the abstract + base mandates the keyword). When the caller provides a positive + value we honour it as the wall-clock budget for any DuckDB query + this step runs; ``None`` and non-positive values fall back to the + module-level :data:`QUERY_TIMEOUT_S` so a mis-configured client + cannot silently disable the watchdog. + """ + if timeout_s is None or timeout_s <= 0: + return QUERY_TIMEOUT_S + return float(timeout_s) + + +def _initial_schema_synopsis(spec: ScenarioSpec, synopsis: str) -> str: + """Reset-time synopsis with future drift details removed. + + Drift scenarios should not reveal the exact schema/business-rule + change before the changelog is published at runtime. We therefore + trim the authored synopsis at the first ``" Under drift"`` clause on + reset and only surface the pre-drift schema shape. + """ + if spec.drift_config is None: + return synopsis + predrift, marker, _ = synopsis.partition(" Under drift") + return predrift if marker else synopsis + + +def _validate_read_only_sql(sql: str) -> None: + """Reject anything that isn't a single-statement read-only SELECT/CTE. + + Raises ``ValueError`` so the caller can translate to a typed + :class:`models.ToolError` with :attr:`ToolErrorCode.INVALID_TOOL_ARGUMENT`. + This is the only place that mediates what the policy may execute; + scenario builders and drift DDL call DuckDB directly with privileged + SQL and deliberately bypass this check. + + Beyond the statement-level gate, this walker also rejects two + sandbox-escape vectors that would otherwise ride along inside a + perfectly-shaped SELECT: + + 1. Table-valued functions that read from the host filesystem + (``read_csv``, ``read_parquet``, ``read_json_auto``, ``glob``, + ``read_text``, …) or leak engine introspection (``duckdb_secrets`` + carries credentials; ``duckdb_settings`` /``duckdb_functions`` + can enumerate available exploits). See :data:`_DENYLIST_PREFIXES` + / :data:`_DENYLIST_EXACT`. + 2. ``SELECT * FROM 'path/to/x.csv'`` — DuckDB treats a bare string + literal in a FROM clause as a filesystem path and auto-detects + the format. There is no function node to inspect in this form, + so we separately reject any :class:`sqlglot.exp.Table` whose + backing expression is a string literal. + """ + try: + statements = sqlglot.parse(sql, dialect="duckdb") + except sqlglot.errors.ParseError as exc: + raise ValueError(f"SQL failed to parse: {exc}") from exc + + non_empty = [s for s in statements if s is not None] + if len(non_empty) != 1: + raise ValueError("multi-statement SQL is not allowed; submit one SELECT") + expr = non_empty[0] + if expr.key not in _READ_ONLY_EXPRESSION_KEYS: + raise ValueError( + f"only read-only SELECT/CTE queries are allowed (got {expr.key.upper()} statement)" + ) + + for node in expr.walk(): + # (1) Function-valued sandbox escapes. Inspect both the class + # name (catches ``ReadCSV`` / ``ReadParquet`` lowerings where + # ``.name`` holds the file path, not the function name) and + # ``.name`` (catches the generic ``Anonymous`` form). + if isinstance(node, sqlglot.exp.Func): + for fn_name in _function_names(node): + if _is_denylisted_function_name(fn_name): + raise ValueError( + f"function {fn_name!r} is not allowed — agent-facing SQL may " + "only touch the scenario's in-memory tables" + ) + # (2) Bare-path FROM form: ``SELECT * FROM 'x.csv'`` or + # ``SELECT * FROM '/etc/passwd'``. sqlglot normalises both + # single- and double-quoted identifiers to + # ``Identifier(quoted=True)``, so we can't rely on the quote + # flavour to distinguish a file path from a legitimately-quoted + # table name. Instead we require every agent-facing table name + # to be a valid unquoted SQL identifier — the scenarios never + # emit anything else, and paths always contain ``/``, ``.`` or + # ``~`` which fail the identifier regex. + if isinstance(node, sqlglot.exp.Table): + inner = node.this + if isinstance(inner, sqlglot.exp.Identifier): + ident_name = inner.name + if ident_name and not _VALID_IDENTIFIER_RE.match(ident_name): + raise ValueError( + f"table identifier {ident_name!r} is not a valid unquoted SQL " + "name — reading from file paths or other engine-specific " + "resources is not allowed" + ) + + +class SqlDriftEnvironment(Environment[SqlDriftAction, SqlDriftObservation, SqlDriftState]): + """OpenEnv environment for SQL repair + optimization under schema drift.""" + + SUPPORTS_CONCURRENT_SESSIONS = True + + def __init__( + self, + skill_store: Store | None = None, + cleanup_on_close: bool = False, + ) -> None: + self._runtime: RuntimeEpisodeState | None = None + self._skill_store: Store | None = skill_store + # When True, the skill-store directory is deleted when close() is called. + # Set this for server-managed per-session stores so disk usage doesn't grow + # monotonically; see design/codereview.md (session store issue). + self._cleanup_on_close: bool = cleanup_on_close + super().__init__( + rubric=SqlDriftRubric(ctx_provider=lambda: self._require_runtime()), + ) + + # ------------------------------------------------------------------ + # OpenEnv contract + # ------------------------------------------------------------------ + + @log_env_reset + def reset( + self, + seed: int | None = None, + episode_id: str | None = None, + **kwargs: Any, + ) -> SqlDriftObservation: + options = _ResetOptions.model_validate(kwargs) + scenario_id = options.scenario_id + enable_dba_oracle = dba_oracle.is_enabled(options.enable_dba_oracle) + difficulty = options.difficulty + budget_steps = options.budget_steps + + if seed is None: + seed = secrets.randbits(31) + if episode_id is None: + episode_id = f"ep-{seed:08x}" + if scenario_id is None: + scenario_id = self._pick_scenario_for_seed(seed) + + spec = get_spec(scenario_id) + instance = spec.materialize(seed, difficulty=difficulty) + + drift_scheduled_step: int | None = None + if instance.drift_config is not None: + drift_scheduled_step = Random(seed).randint( + instance.drift_config.min_step, + instance.drift_config.max_step, + ) + + self._close_existing_runtime() + self._runtime = RuntimeEpisodeState( + episode_id=episode_id, + seed=seed, + scenario_id=scenario_id, + instance=instance, + conn=instance.conn, + gt_result_hash_predrift=instance.gt_result_hash_predrift, + gt_result_hash_postdrift=None, + baseline_runtime_ms=instance.baseline_runtime_ms, + baseline_tokens=instance.baseline_tokens, + baseline_sql_canonical=canonicalize_sql(instance.baseline_sql), + baseline_postdrift_raises=False, + drift_scheduled_step=drift_scheduled_step, + budget_steps=budget_steps, + dba_oracle_enabled=enable_dba_oracle, + ) + + self._reset_rubric() + + learned_hints = kwargs.get("learned_hints") + if learned_hints is None: + learned_hints = self._render_learned_hints(spec, include_drift_cards=False) + if len(learned_hints) > 800: + learned_hints = learned_hints[:800] + + rt = self._require_runtime() + return SqlDriftObservation( + step=0, + phase=EpisodePhase.DIAGNOSE, + last_tool=None, + tool_result=None, + drift_fired=False, + drift_acknowledged=False, + learned_hints=learned_hints, + baseline_sql=instance.baseline_sql, + schema_synopsis=_initial_schema_synopsis(spec, instance.schema_synopsis), + budget_steps_remaining=rt.budget_steps_remaining, + reward_components={key: 0.0 for key in REWARD_COMPONENT_KEYS}, + done=False, + reward=None, + ) + + @log_env_step + def step( + self, + action: SqlDriftAction, + timeout_s: float | None = None, + **kwargs: Any, + ) -> SqlDriftObservation: + rt = self._require_runtime() + if rt.submitted or rt.budget_steps_remaining <= 0: + raise ValueError("Episode is already finished; call reset() to start a new episode.") + rt.step_count += 1 + rt.last_step_was_tool_error = False + rt.last_step_was_repeat_failing_query = False + rt.last_step_repeat_failing_query_count = 0 + rt.last_step_productive_rebate = 0.0 + + self._maybe_fire_drift() + + effective_timeout_s = _resolve_timeout_s(timeout_s) + try: + tool_result = self._dispatch(action, timeout_s=effective_timeout_s) + except QueryWatchdogEscalationError: + rt.connection_poisoned = True + rt.phase = EpisodePhase.FINALIZE + rt.step_count = max(rt.step_count, rt.budget_steps) + _LOG.error("episode %s aborted after watchdog escalation", rt.episode_id) + raise + rt.last_step_was_tool_error = isinstance(tool_result, ToolError) + if rt.last_step_was_tool_error: + rt.consecutive_tool_errors += 1 + else: + rt.consecutive_tool_errors = 0 + + done = rt.submitted or rt.budget_steps_remaining <= 0 + + obs = SqlDriftObservation( + step=rt.step_count, + phase=rt.phase, + last_tool=action.tool, + tool_result=tool_result, + drift_fired=rt.drift_fired, + drift_acknowledged=rt.drift_acknowledged, + learned_hints="", + baseline_sql="", + schema_synopsis="", + budget_steps_remaining=rt.budget_steps_remaining, + reward_components={key: 0.0 for key in REWARD_COMPONENT_KEYS}, + done=done, + reward=None, + ) + if rt.drift_acknowledged: + spec = get_spec(rt.scenario_id) + obs.learned_hints = self._render_learned_hints(spec, include_drift_cards=True) + + obs.reward = self._apply_rubric(action, obs) + if self.rubric is not None: + obs.reward_components = self.rubric.component_scores() + + if done and rt.submitted: + self._maybe_persist_learned_entry() + return obs + + def render(self) -> dict[str, Any]: + """Render the current public state and log the render interaction.""" + rt = self._require_runtime() + state = self.state + payload = state.model_dump(mode="json") + log_interaction( + event_type="render", + agent_id=rt.episode_id, + observation_returned=payload, + done=rt.submitted or rt.budget_steps_remaining <= 0, + ) + return payload + + @property + def state(self) -> SqlDriftState: + """Sanitised public state snapshot (explicit whitelist).""" + rt = self._require_runtime() + return SqlDriftState( + episode_id=rt.episode_id, + step_count=rt.step_count, + scenario_id=rt.scenario_id, + phase=rt.phase, + budget_steps_remaining=rt.budget_steps_remaining, + drift_fired=rt.drift_fired, + consultations_used=rt.consultations_used, + submitted=rt.submitted, + ) + + def effective_speedup(self) -> float | None: + """Return the current episode's effective speedup, if any.""" + rt = self._runtime + if rt is None: + return None + return effective_speedup(rt) + + def close(self) -> None: + self._close_existing_runtime() + if self._cleanup_on_close and self._skill_store is not None: + import shutil + + store_dir = self._skill_store.dir + shutil.rmtree(store_dir, ignore_errors=True) + + # ------------------------------------------------------------------ + # Skill-library wiring + # ------------------------------------------------------------------ + + def _render_learned_hints(self, spec: ScenarioSpec, *, include_drift_cards: bool = True) -> str: + playbook, drift_cards = load_all(self._skill_store) + drift_kind = None + if include_drift_cards and spec.drift_config is not None: + drift_kind = spec.drift_config.kind + result = retrieve( + query_tags=spec.tags, + drift_kind=drift_kind, + playbook=playbook, + drift_cards=drift_cards, + ) + return result.render(max_chars=800) + + def _maybe_persist_learned_entry(self) -> None: + """Append a PlaybookEntry on terminal success with a meaningful speedup. + + Failures to persist are logged but never re-raised: a training + rollout should not crash because the on-disk playbook is under + contention. The skill store itself is crash-safe (atomic writes + + file-lock) so at-most-once semantics are sufficient here. + """ + if self._skill_store is None: + return + rt = self._require_runtime() + if not rt.submitted: + return + if self.rubric is None: + return + scores = self.rubric.component_scores() + if scores.get("r_correct", 0.0) < 1.0: + return + spec = get_spec(rt.scenario_id) + raw_speedup = effective_speedup(rt) + # effective_speedup cannot return None here — rt.submitted is True + # so submitted_runtime_ms is populated — but we guard defensively. + # ``+∞`` (drift invalidated the baseline) is capped so the on-disk + # playbook doesn't serialize ``Infinity``, which would round-trip + # as a JSON parse error on load. + if raw_speedup is None or math.isinf(raw_speedup): + speedup_val = float(SPEEDUP_CAP_FOR_INFTY) + else: + speedup_val = float(raw_speedup) + entry = PlaybookEntry( + tag_set=spec.tags, + before_snippet=rt.instance.baseline_sql[:200], + after_snippet=(rt.submitted_sql or "")[:200], + avg_speedup=speedup_val, + scenario_family=spec.family, + source="learned", + ) + try: + self._skill_store.append_playbook(entry) + except Exception as exc: + _LOG.warning("skill-library append_playbook failed: %s", exc) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _grant_step_rebate_once(self, *, attr: str, rebate: float) -> None: + rt = self._require_runtime() + if getattr(rt, attr): + return + setattr(rt, attr, True) + rt.last_step_productive_rebate += rebate + + def _grant_step_rebate_for_table( + self, *, rewarded_tables_attr: str, table: str, rebate: float + ) -> None: + rt = self._require_runtime() + rewarded = getattr(rt, rewarded_tables_attr) + if table in rewarded: + return + rewarded.add(table) + rt.last_step_productive_rebate += rebate + + @staticmethod + def _pick_scenario_for_seed(seed: int) -> str: + """Deterministic round-robin over the sorted scenario registry.""" + ids = sorted(REGISTRY) + if not ids: + raise RuntimeError("no scenarios registered") + return ids[seed % len(ids)] + + def _require_runtime(self) -> RuntimeEpisodeState: + if self._runtime is None: + raise RuntimeError("SqlDriftEnvironment.reset() must be called before step()/state.") + return self._runtime + + def _close_existing_runtime(self) -> None: + if self._runtime is not None: + if self._runtime.connection_poisoned: + _LOG.error( + "skipping close for poisoned DuckDB connection in episode %s", + self._runtime.episode_id, + ) + else: + with contextlib.suppress(duckdb.Error): + self._runtime.conn.close() + self._runtime = None + + def _maybe_fire_drift(self) -> None: + """Apply drift when the step index crosses the schedule/cooldown threshold.""" + rt = self._require_runtime() + if rt.drift_fired: + return + if rt.drift_scheduled_step is None: + return + if rt.first_run_query_step is None: + return + cfg = rt.instance.drift_config + assert cfg is not None + minimum = max(rt.drift_scheduled_step, rt.first_run_query_step + cfg.cooldown_steps) + if rt.step_count < minimum: + return + self._fire_drift() + + def _fire_drift(self) -> None: + """Apply drift, author a changelog, and resolve the post-drift GT hash. + + Failure to recompute the post-drift GT hash is an authoring bug + (the scenario's ``gt_sql_postdrift`` must execute against the + just-mutated DB) and we re-raise loudly so it cannot silently + make every post-drift submission score ``r_correct=0``. + """ + rt = self._require_runtime() + cfg = rt.instance.drift_config + assert cfg is not None + apply_drift(rt.conn, cfg.kind, cfg.payload) + rt.drift_fired_step = rt.step_count + rt.phase = EpisodePhase.DRIFT_RECOVERY + rt.changelog_entries.append(author_changelog(cfg)) + + try: + rt.conn.execute(rt.instance.baseline_sql).fetchall() + rt.baseline_postdrift_raises = False + except duckdb.Error: + rt.baseline_postdrift_raises = True + + if rt.instance.gt_sql_postdrift is not None: + try: + rows = rt.conn.execute(rt.instance.gt_sql_postdrift).fetchall() + except duckdb.Error as exc: + raise RuntimeError( + f"scenario {rt.scenario_id!r}: authored gt_sql_postdrift failed " + f"after drift: {exc}" + ) from exc + rt.gt_result_hash_postdrift = canonical_row_hash(rows) + + # ------------------------------------------------------------------ + # Tool dispatch + # ------------------------------------------------------------------ + + def _dispatch(self, action: SqlDriftAction, *, timeout_s: float) -> ToolResult: + payload = action.payload + try: + if isinstance(payload, ListTablesPayload): + return self._handle_list_tables() + if isinstance(payload, DescribeTablePayload): + return self._handle_describe_table(payload) + if isinstance(payload, SampleRowsPayload): + return self._handle_sample_rows(payload) + if isinstance(payload, RunQueryPayload): + return self._handle_run_query(payload, timeout_s=timeout_s) + if isinstance(payload, ExplainQueryPayload): + return self._handle_explain_query(payload, timeout_s=timeout_s) + if isinstance(payload, ReadChangelogPayload): + return self._handle_read_changelog() + if isinstance(payload, SubmitRewritePayload): + return self._handle_submit_rewrite(payload, timeout_s=timeout_s) + if isinstance(payload, ConsultDBAPayload): + return self._handle_consult_dba(payload) + except duckdb.Error as exc: + return ToolError(code=ToolErrorCode.DB_ERROR, message=str(exc)[:2000]) + except TimeoutError as exc: + return ToolError(code=ToolErrorCode.QUERY_TIMEOUT, message=str(exc)[:2000]) + # Unreachable — the discriminated-union validator rejects unknown payloads. + return ToolError( + code=ToolErrorCode.INVALID_TOOL_ARGUMENT, + message=f"unknown payload type: {type(payload).__name__}", + ) + + def _handle_list_tables(self) -> ListTablesResult: + rt = self._require_runtime() + rows = rt.conn.execute( + "SELECT table_name FROM information_schema.tables " + "WHERE table_schema = 'main' ORDER BY table_name" + ).fetchall() + self._grant_step_rebate_once(attr="listed_tables_rewarded", rebate=STEP_REBATE_LIST_TABLES) + self._mark_diagnostic() + return ListTablesResult(tables=[r[0] for r in rows]) + + def _handle_describe_table( + self, payload: DescribeTablePayload + ) -> DescribeTableResult | ToolError: + rt = self._require_runtime() + rows = rt.conn.execute( + "SELECT column_name, data_type FROM information_schema.columns " + "WHERE table_name = ? ORDER BY ordinal_position", + [payload.table], + ).fetchall() + if not rows: + return ToolError( + code=ToolErrorCode.UNKNOWN_TABLE, + message=f"unknown table: {payload.table}", + ) + self._grant_step_rebate_for_table( + rewarded_tables_attr="described_tables_rewarded", + table=payload.table, + rebate=STEP_REBATE_DESCRIBE_TABLE, + ) + self._mark_diagnostic() + return DescribeTableResult( + table=payload.table, + columns=[{"name": r[0], "type": r[1]} for r in rows], + ) + + def _handle_sample_rows(self, payload: SampleRowsPayload) -> SampleRowsResult | ToolError: + rt = self._require_runtime() + exists = rt.conn.execute( + "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = ?", + [payload.table], + ).fetchone() + if not exists or exists[0] == 0: + return ToolError( + code=ToolErrorCode.UNKNOWN_TABLE, + message=f"unknown table: {payload.table}", + ) + cur = rt.conn.execute(f'SELECT * FROM "{payload.table}" LIMIT {payload.limit}') + columns = [d[0] for d in cur.description] if cur.description else [] + rows = [list(r) for r in cur.fetchall()] + self._grant_step_rebate_for_table( + rewarded_tables_attr="sampled_tables_rewarded", + table=payload.table, + rebate=STEP_REBATE_SAMPLE_ROWS, + ) + self._mark_diagnostic() + return SampleRowsResult(table=payload.table, columns=columns, rows=rows) + + def _handle_run_query( + self, payload: RunQueryPayload, *, timeout_s: float + ) -> RunQueryResult | ToolError: + rt = self._require_runtime() + sql = payload.sql + try: + _validate_read_only_sql(sql) + except ValueError as exc: + return ToolError(code=ToolErrorCode.INVALID_TOOL_ARGUMENT, message=str(exc)[:2000]) + + # Drift timing: after a valid + # ``run_query`` attempt, the pre-drift probe invariant is + # satisfied regardless of whether the execution ultimately + # returned rows, raised, or was capped for size. Assigning + # *before* execution means truncation, DB errors, and timeouts + # can no longer suppress drift firing in later steps. + if rt.first_run_query_step is None: + rt.first_run_query_step = rt.step_count + + try: + result = execute_once_with_columns( + rt.conn, sql, timeout_s=timeout_s, max_rows=MAX_RESULT_ROWS + ) + except TimeoutError as exc: + return ToolError(code=ToolErrorCode.QUERY_TIMEOUT, message=str(exc)[:2000]) + except duckdb.Error as exc: + # Canonicalize *before* hashing so whitespace-/case-only + # variants of the same broken query count as the same repeat + # offence. canonicalize_sql falls back to a whitespace fold + # for SQL that sqlglot can't parse — still normalises the + # vast majority of "retried the same typo" cases. + failure_hash = canonical_row_hash([(canonicalize_sql(sql),)]) + count = rt.failed_query_counts.get(failure_hash, 0) + 1 + rt.failed_query_counts[failure_hash] = count + rt.failed_query_hashes.add(failure_hash) + rt.last_step_repeat_failing_query_count = count + rt.last_step_was_repeat_failing_query = count > 1 + return ToolError(code=ToolErrorCode.DB_ERROR, message=str(exc)[:2000]) + + if result.truncated: + return ToolError( + code=ToolErrorCode.RESULT_TOO_LARGE, + message=( + f"result exceeded {MAX_RESULT_ROWS}-row cap — narrow the " + "projection, add a LIMIT, or aggregate" + ), + ) + + self._grant_step_rebate_once(attr="run_query_rewarded", rebate=STEP_REBATE_RUN_QUERY) + self._mark_diagnostic() + return RunQueryResult( + columns=result.columns, + rows=[list(r) for r in result.rows], + runtime_ms=result.elapsed_ms, + row_count=len(result.rows), + ) + + def _handle_explain_query( + self, payload: ExplainQueryPayload, *, timeout_s: float + ) -> ExplainQueryResult | ToolError: + rt = self._require_runtime() + try: + _validate_read_only_sql(payload.sql) + except ValueError as exc: + return ToolError(code=ToolErrorCode.INVALID_TOOL_ARGUMENT, message=str(exc)[:2000]) + # EXPLAIN is plan-only (no data materialisation) but we still + # route it through the watchdog so a pathological query cannot + # burn the step budget past the caller's wall-clock deadline. + explain_rows, _ = execute_once_timed(rt.conn, f"EXPLAIN {payload.sql}", timeout_s=timeout_s) + plan = "\n".join(str(r[-1]) if r else "" for r in explain_rows) + self._grant_step_rebate_once( + attr="explain_query_rewarded", rebate=STEP_REBATE_EXPLAIN_QUERY + ) + self._mark_diagnostic() + return ExplainQueryResult(plan=plan[:10_000]) + + def _handle_read_changelog(self) -> ReadChangelogResult: + rt = self._require_runtime() + if rt.changelog_entries: + rt.drift_acknowledged = True + self._grant_step_rebate_once( + attr="changelog_rewarded_after_drift", + rebate=STEP_REBATE_READ_CHANGELOG, + ) + self._mark_diagnostic() + return ReadChangelogResult(entries=list(rt.changelog_entries)) + + def _handle_submit_rewrite( + self, payload: SubmitRewritePayload, *, timeout_s: float + ) -> SubmitRewriteResult | ToolError: + rt = self._require_runtime() + if not rt.diagnostic_actions_taken: + return ToolError( + code=ToolErrorCode.SUBMIT_BEFORE_DIAGNOSE, + message=( + "submit_rewrite rejected: the agent must take at least one " + "diagnostic action (list_tables, describe_table, sample_rows, " + "run_query, explain_query, or read_changelog) before submitting." + ), + ) + sql = payload.sql + try: + _validate_read_only_sql(sql) + except ValueError as exc: + return ToolError(code=ToolErrorCode.INVALID_TOOL_ARGUMENT, message=str(exc)[:2000]) + try: + agent_hash, elapsed_ms = execute_hash_timed(rt.conn, sql, timeout_s=timeout_s) + except TimeoutError as exc: + return ToolError(code=ToolErrorCode.QUERY_TIMEOUT, message=str(exc)[:2000]) + except duckdb.Error as exc: + return ToolError(code=ToolErrorCode.DB_ERROR, message=str(exc)[:2000]) + gt_hash = ( + rt.gt_result_hash_postdrift + if rt.drift_fired and rt.gt_result_hash_postdrift is not None + else rt.gt_result_hash_predrift + ) + matches = agent_hash == gt_hash + + rt.submitted = True + rt.submitted_sql = sql + rt.submitted_sql_canonical = canonicalize_sql(sql) + rt.submitted_result_hash = agent_hash + rt.submitted_runtime_ms = elapsed_ms + rt.phase = EpisodePhase.FINALIZE + return SubmitRewriteResult( + accepted=True, + runtime_ms=elapsed_ms, + matches_ground_truth=matches, + ) + + def _handle_consult_dba(self, payload: ConsultDBAPayload) -> ConsultDBAResult | ToolError: + rt = self._require_runtime() + if not rt.dba_oracle_enabled: + return ToolError( + code=ToolErrorCode.INVALID_TOOL_ARGUMENT, + message="consult_dba disabled — set enable_dba_oracle=True at reset()", + ) + if not dba_oracle.has_hints(rt.scenario_id): + return ToolError( + code=ToolErrorCode.INVALID_TOOL_ARGUMENT, + message=f"no DBA hints registered for scenario={rt.scenario_id!r}", + ) + rt.consultations_used += 1 + tier = min(rt.consultations_used, 3) + hint = dba_oracle.get_hint(rt.scenario_id, tier) + del payload # question is free-text context only; hints are scenario-keyed. + return ConsultDBAResult(tier=tier, hint=hint) + + def _mark_diagnostic(self) -> None: + """Record a successful diagnostic tool call and advance the phase machine.""" + rt = self._require_runtime() + rt.diagnostic_actions_taken += 1 + if rt.phase == EpisodePhase.DIAGNOSE: + rt.phase = EpisodePhase.REWRITE + + +__all__ = [ + "DEFAULT_STEP_BUDGET", + "MAX_RESULT_ROWS", + "QUERY_TIMEOUT_S", + "SqlDriftEnvironment", +] diff --git a/skill_library/__init__.py b/skill_library/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..449018dd76667fb03c853e08cee799724f4471cd --- /dev/null +++ b/skill_library/__init__.py @@ -0,0 +1,54 @@ +"""Self-improvement library — pre-seeded playbook + on-disk learned store.""" + +from skill_library.entries import ( + DriftAdaptationCard, + PlaybookEntry, + RetrievalResult, +) +from skill_library.retrieval import ( + JACCARD_MIN, + jaccard, + retrieve, + top_k_drift_cards, + top_k_playbook, +) +from skill_library.seeds import PRESEED_DRIFT_CARDS, PRESEED_PLAYBOOK +from skill_library.store import DEFAULT_STORE_DIR, Store, cleanup_stale_session_dirs + + +def load_all( + store: Store | None = None, +) -> tuple[ + tuple[PlaybookEntry, ...], + tuple[DriftAdaptationCard, ...], +]: + """Union of pre-seeds and any entries persisted on disk. + + Returns ``(playbook, drift_cards)``. Order: pre-seeds first, then + learned entries, so deterministic retrieval tie-breaks prefer + the hand-authored pre-seeds when tags and speedup match exactly. + """ + learned_pb: tuple[PlaybookEntry, ...] = () + learned_dc: tuple[DriftAdaptationCard, ...] = () + if store is not None: + learned_pb = store.read_playbook() + learned_dc = store.read_drift_cards() + return PRESEED_PLAYBOOK + learned_pb, PRESEED_DRIFT_CARDS + learned_dc + + +__all__ = [ + "DEFAULT_STORE_DIR", + "DriftAdaptationCard", + "JACCARD_MIN", + "PRESEED_DRIFT_CARDS", + "PRESEED_PLAYBOOK", + "PlaybookEntry", + "RetrievalResult", + "Store", + "cleanup_stale_session_dirs", + "jaccard", + "load_all", + "retrieve", + "top_k_drift_cards", + "top_k_playbook", +] diff --git a/skill_library/entries.py b/skill_library/entries.py new file mode 100644 index 0000000000000000000000000000000000000000..073855af3df718e32331a8646ddafb5aedecbb91 --- /dev/null +++ b/skill_library/entries.py @@ -0,0 +1,93 @@ +"""Data classes for the self-improvement (skill) library. + +Two entry kinds — both immutable dataclasses so they can live on +frozen-dict caches and be safely shared across episodes. + +- :class:`PlaybookEntry` — a "before/after" SQL rewrite nugget, + tagged by anti-pattern + scenario family, with an empirical + speedup number. Populated by pre-seeds and extended at + terminal-success (``r_correct > 0 ∧ speedup > 1.2``). +- :class:`DriftAdaptationCard` — a drift-kind recovery card with a + symptom regex and a recovery template. Pre-seeded 1-per-drift-kind. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Literal + +DriftKind = Literal["column_rename", "date_format", "enum_rule", "field_deprecation"] + + +@dataclass(frozen=True) +class PlaybookEntry: + """A reusable SQL-rewrite recipe keyed by anti-pattern tags.""" + + tag_set: frozenset[str] + before_snippet: str + after_snippet: str + avg_speedup: float + scenario_family: str # matches scenarios.base.Family + # Optional provenance — "preseed" for hand-authored, "learned" for + # append-on-success entries. Used by the retrieval UI only. + source: Literal["preseed", "learned"] = "preseed" + + def render_hint(self, max_chars: int = 200) -> str: + """Render a one-liner suitable for inclusion in ``learned_hints``.""" + body = ( + f"[{self.scenario_family}] " + f"replace `{self.before_snippet[:60]}...` with " + f"`{self.after_snippet[:60]}...` " + f"(~{self.avg_speedup:.1f}x)" + ) + return body[:max_chars] + + +@dataclass(frozen=True) +class DriftAdaptationCard: + """A drift-kind recovery card.""" + + drift_kind: DriftKind + symptom_regex: str + recovery_template: str + success_rate: float = 0.0 + source: Literal["preseed", "learned"] = "preseed" + + def render_hint(self, max_chars: int = 200) -> str: + body = ( + f"[drift:{self.drift_kind}] " + f"symptom=/{self.symptom_regex}/ → " + f"{self.recovery_template[:120]}" + ) + return body[:max_chars] + + +@dataclass(frozen=True) +class RetrievalResult: + """Top-k blend of playbook hits + drift cards for one retrieval call.""" + + playbook: tuple[PlaybookEntry, ...] = field(default_factory=tuple) + drift_cards: tuple[DriftAdaptationCard, ...] = field(default_factory=tuple) + + def render(self, max_chars: int = 800) -> str: + """Concatenate rendered hints, truncated to ``max_chars``. + + Deterministic ordering: playbook entries first (by descending + ``avg_speedup``, ties broken by ``before_snippet``), then drift + cards (by descending ``success_rate``, ties by ``drift_kind``). + """ + lines: list[str] = [] + for e in self.playbook: + lines.append("- " + e.render_hint(max_chars=200)) + for c in self.drift_cards: + lines.append("- " + c.render_hint(max_chars=200)) + out = "\n".join(lines) + return out[:max_chars] + + +__all__ = [ + "DriftAdaptationCard", + "DriftKind", + "PlaybookEntry", + "RetrievalResult", +] diff --git a/skill_library/retrieval.py b/skill_library/retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..99b1c7525aff9fce039c2655593e0f1d962499e8 --- /dev/null +++ b/skill_library/retrieval.py @@ -0,0 +1,85 @@ +"""Pure-function Jaccard top-k retrieval over tag sets. + +No embeddings, no torch, no tokenizer. Deterministic — same inputs +produce the same ranking, same top-k, same tie-break. + +A conservative Jaccard threshold (0.3) limits retrieval noise when +broad pre-seeds would otherwise match every scenario. +""" + +from __future__ import annotations + +from collections.abc import Iterable + +from skill_library.entries import ( + DriftAdaptationCard, + PlaybookEntry, + RetrievalResult, +) + +JACCARD_MIN: float = 0.3 + + +def jaccard(a: frozenset[str], b: frozenset[str]) -> float: + """Standard Jaccard on sets.""" + if not a and not b: + return 1.0 + if not a or not b: + return 0.0 + return len(a & b) / len(a | b) + + +def top_k_playbook( + query_tags: frozenset[str], + entries: Iterable[PlaybookEntry], + k: int = 3, + *, + min_overlap: float = JACCARD_MIN, +) -> tuple[PlaybookEntry, ...]: + """Top-k playbook entries by Jaccard overlap with ``query_tags``. + + Ties broken by descending ``avg_speedup`` then by ``before_snippet`` + lexicographic order so the result is stable across runs. + """ + scored = [(jaccard(query_tags, e.tag_set), e) for e in entries] + scored = [(j, e) for j, e in scored if j >= min_overlap] + scored.sort(key=lambda t: (-t[0], -t[1].avg_speedup, t[1].before_snippet)) + return tuple(e for _, e in scored[:k]) + + +def top_k_drift_cards( + drift_kind: str | None, + cards: Iterable[DriftAdaptationCard], + k: int = 1, +) -> tuple[DriftAdaptationCard, ...]: + """Filter cards by exact drift_kind match, sorted by success_rate desc.""" + if drift_kind is None: + return () + matches = [c for c in cards if c.drift_kind == drift_kind] + matches.sort(key=lambda c: (-c.success_rate, c.drift_kind)) + return tuple(matches[:k]) + + +def retrieve( + query_tags: frozenset[str], + drift_kind: str | None, + playbook: Iterable[PlaybookEntry], + drift_cards: Iterable[DriftAdaptationCard], + *, + playbook_k: int = 3, + drift_k: int = 1, +) -> RetrievalResult: + """Combined retrieval: top-k playbook + top-k drift cards.""" + return RetrievalResult( + playbook=top_k_playbook(query_tags, playbook, k=playbook_k), + drift_cards=top_k_drift_cards(drift_kind, drift_cards, k=drift_k), + ) + + +__all__ = [ + "JACCARD_MIN", + "jaccard", + "retrieve", + "top_k_drift_cards", + "top_k_playbook", +] diff --git a/skill_library/seeds.py b/skill_library/seeds.py new file mode 100644 index 0000000000000000000000000000000000000000..cf4b8bdfb5f708abb7061c9f7bfbb3883a12ab63 --- /dev/null +++ b/skill_library/seeds.py @@ -0,0 +1,126 @@ +"""Hand-authored pre-seed library — 8 playbook entries + 4 drift cards. + +Tag sets here match the tags on each corresponding scenario +so a Jaccard top-k against the scenario's tag set returns +a relevant entry for every one of the 10 scenarios — with two +generic fallbacks for coverage on novel anti-patterns. +""" + +from __future__ import annotations + +from skill_library.entries import DriftAdaptationCard, PlaybookEntry + +PRESEED_PLAYBOOK: tuple[PlaybookEntry, ...] = ( + # 1. Correlated subquery → LEFT JOIN + GROUP BY + PlaybookEntry( + tag_set=frozenset({"correlated_subquery", "projection_subquery"}), + before_snippet=( + "SELECT u.*, (SELECT COUNT(*) FROM orders o WHERE o.user_id=u.id) FROM users u" + ), + after_snippet=( + "SELECT u.*, COALESCE(c.n, 0) FROM users u " + "LEFT JOIN (SELECT user_id, COUNT(*) n FROM orders GROUP BY user_id) c " + "ON c.user_id = u.id" + ), + avg_speedup=6.0, + scenario_family="ecommerce", + ), + # 2. SELECT * + join → project only needed columns + PlaybookEntry( + tag_set=frozenset({"select_star", "over_projection", "join"}), + before_snippet="SELECT * FROM a JOIN b ON a.id=b.a_id", + after_snippet="SELECT a.id, a.name, b.amount FROM a JOIN b ON a.id=b.a_id", + avg_speedup=2.5, + scenario_family="ecommerce", + ), + # 3. Cartesian join — add explicit ON clause + PlaybookEntry( + tag_set=frozenset({"cartesian", "missing_join_condition"}), + before_snippet="SELECT * FROM a, b WHERE a.region = 'US'", + after_snippet="SELECT a.col FROM a JOIN b ON a.id = b.a_id WHERE a.region = 'US'", + avg_speedup=50.0, + scenario_family="events", + ), + # 4. DISTINCT on GROUP BY — drop one + PlaybookEntry( + tag_set=frozenset({"distinct", "redundant_distinct", "group_by"}), + before_snippet="SELECT DISTINCT tenant_id, count(*) FROM logs GROUP BY tenant_id", + after_snippet="SELECT tenant_id, count(*) FROM logs GROUP BY tenant_id", + avg_speedup=1.4, + scenario_family="saas_logs", + ), + # 5. Nested IN-subquery → JOIN + PlaybookEntry( + tag_set=frozenset({"nested_subquery", "in_subquery"}), + before_snippet="WHERE id IN (SELECT x_id FROM x WHERE ... )", + after_snippet="JOIN x ON x.x_id = table.id WHERE ...", + avg_speedup=3.0, + scenario_family="cms", + ), + # 6. HAVING filter on groupable column → push to WHERE + PlaybookEntry( + tag_set=frozenset({"having_as_where", "aggregate_filter"}), + before_snippet="GROUP BY x, status HAVING status = 'fulfilled'", + after_snippet="WHERE status = 'fulfilled' GROUP BY x", + avg_speedup=2.0, + scenario_family="ecommerce", + ), + # 7. Generic: prefer JOINs over correlated subqueries + PlaybookEntry( + tag_set=frozenset({"subquery", "generic"}), + before_snippet="scalar subquery in SELECT list", + after_snippet="LEFT JOIN with aggregated CTE", + avg_speedup=4.0, + scenario_family="ecommerce", + ), + # 8. Generic: project only used columns + PlaybookEntry( + tag_set=frozenset({"over_projection", "generic"}), + before_snippet="SELECT *", + after_snippet="SELECT ", + avg_speedup=1.8, + scenario_family="ecommerce", + ), +) + + +PRESEED_DRIFT_CARDS: tuple[DriftAdaptationCard, ...] = ( + DriftAdaptationCard( + drift_kind="column_rename", + symptom_regex=r'column ".+" does not exist', + recovery_template=( + "Read the changelog, update every identifier referencing the old column, and resubmit." + ), + success_rate=0.9, + ), + DriftAdaptationCard( + drift_kind="date_format", + symptom_regex=r"Could not convert string .+ to TIMESTAMP|BIGINT", + recovery_template=( + "Epoch-ms columns are BIGINT; cast your filter bounds with " + "`EXTRACT(EPOCH FROM TIMESTAMP '...') * 1000` or use numeric literals." + ), + success_rate=0.85, + ), + DriftAdaptationCard( + drift_kind="enum_rule", + symptom_regex=r"(empty|zero) result set on filter `... = 'active'`", + recovery_template=( + "A single enum value may have been split into several; use `IN " + "('ACTIVE', 'ACTIVE_V2')` instead of equality." + ), + success_rate=0.8, + ), + DriftAdaptationCard( + drift_kind="field_deprecation", + symptom_regex=r'column ".+" does not exist|non-existent column', + recovery_template=( + "The inline field was replaced by a FK; JOIN the lookup table and " + "project the human-readable name from there." + ), + success_rate=0.75, + ), +) + + +__all__ = ["PRESEED_DRIFT_CARDS", "PRESEED_PLAYBOOK"] diff --git a/skill_library/store.py b/skill_library/store.py new file mode 100644 index 0000000000000000000000000000000000000000..53019f92cbb117182d59970b1903e765f0c1d23c --- /dev/null +++ b/skill_library/store.py @@ -0,0 +1,257 @@ +"""JSON-backed playbook and drift-card store with file locking. + +Each store file is a JSON array. Appends go through a single locked +read-modify-write cycle that writes to a sibling ``*.tmp`` file and +atomically ``os.replace``s it onto the target path, so a crash can only +leave either the old array or the new one — never a truncated file. + +The lock is held on a dedicated ``*.lock`` file via ``fcntl.flock`` with +a caller-configurable timeout (default 5s). We never lock the data file +itself: that way an ``os.replace`` inside the critical section can't +race against a reader holding a shared lock on the old inode. + +Reads are cached by mtime so hot-path episodes don't re-parse the file +on every ``reset()``. Corrupt trailers (from a pre-atomic-write era or +a partial disk write) log a warning and fall back to empty — we prefer +a running trainer over one that dies because of a bad card. +""" + +from __future__ import annotations + +import contextlib +import json +import os +import time +from collections.abc import Callable, Iterator +from dataclasses import asdict +from pathlib import Path +from typing import Any, Literal, TypeVar, cast + +from skill_library.entries import DriftAdaptationCard, PlaybookEntry +from utilities.logger import get_module_logger + +_LOG = get_module_logger(__name__) + +DEFAULT_STORE_DIR = Path("outputs") / "skill_library" +PLAYBOOK_FILENAME = "playbook.json" +DRIFT_CARDS_FILENAME = "drift_cards.json" +DEFAULT_LOCK_TIMEOUT_S: float = 5.0 + +T = TypeVar("T") + + +try: + import fcntl + + def _try_lock_exclusive(fh: Any) -> bool: + try: + fcntl.flock(fh.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) + return True + except BlockingIOError: + return False + + def _unlock(fh: Any) -> None: + fcntl.flock(fh.fileno(), fcntl.LOCK_UN) + + _HAS_FCNTL = True + +except ImportError: + _HAS_FCNTL = False + + def _try_lock_exclusive(fh: Any) -> bool: + return True + + def _unlock(fh: Any) -> None: + return None + + +@contextlib.contextmanager +def _locked(path: Path, timeout_s: float) -> Iterator[None]: + """Poll-acquire an exclusive flock on ``path`` within ``timeout_s``.""" + path.parent.mkdir(parents=True, exist_ok=True) + deadline = time.monotonic() + timeout_s + with path.open("a+") as fh: + while not _try_lock_exclusive(fh): + if time.monotonic() >= deadline: + raise TimeoutError(f"could not acquire {path} within {timeout_s}s") + time.sleep(0.02) + try: + yield + finally: + if _HAS_FCNTL: + _unlock(fh) + + +def _atomic_write_json(path: Path, payload: list[dict[str, Any]]) -> None: + tmp = path.with_suffix(path.suffix + ".tmp") + text = json.dumps(payload, indent=2) + with tmp.open("w", encoding="utf-8") as fh: + fh.write(text) + fh.flush() + os.fsync(fh.fileno()) + os.replace(tmp, path) + + +def _read_json_array(path: Path) -> list[Any]: + if not path.exists(): + return [] + try: + raw = path.read_text(encoding="utf-8") + except OSError as exc: + _LOG.warning("skill-store read failed for %s: %s", path, exc) + return [] + if not raw.strip(): + return [] + try: + data = json.loads(raw) + except json.JSONDecodeError as exc: + _LOG.warning("skill-store corrupt at %s (%s); returning empty", path, exc) + return [] + return data if isinstance(data, list) else [] + + +class Store: + """Append-only JSON store for learned playbook entries + drift cards.""" + + def __init__( + self, + directory: Path | None = None, + lock_timeout_s: float = DEFAULT_LOCK_TIMEOUT_S, + ) -> None: + self.dir = Path(directory) if directory is not None else DEFAULT_STORE_DIR + self.lock_timeout_s = lock_timeout_s + self._playbook_cache: tuple[tuple[PlaybookEntry, ...], float] | None = None + self._drift_cache: tuple[tuple[DriftAdaptationCard, ...], float] | None = None + + def playbook_path(self) -> Path: + return self.dir / PLAYBOOK_FILENAME + + def drift_cards_path(self) -> Path: + return self.dir / DRIFT_CARDS_FILENAME + + def read_playbook(self) -> tuple[PlaybookEntry, ...]: + return self._read_cached( + self.playbook_path(), + cache_attr="_playbook_cache", + decode=_entry_from_dict, + ) + + def read_drift_cards(self) -> tuple[DriftAdaptationCard, ...]: + return self._read_cached( + self.drift_cards_path(), + cache_attr="_drift_cache", + decode=lambda d: DriftAdaptationCard(**d), + ) + + def append_playbook(self, entry: PlaybookEntry) -> None: + self._locked_append( + self.playbook_path(), + encode_new=_entry_to_dict, + new_item=entry, + ) + self._playbook_cache = None + + def append_drift_card(self, card: DriftAdaptationCard) -> None: + self._locked_append( + self.drift_cards_path(), + encode_new=asdict, + new_item=card, + ) + self._drift_cache = None + + def _read_cached( + self, + path: Path, + *, + cache_attr: str, + decode: Callable[[dict[str, Any]], T], + ) -> tuple[T, ...]: + mtime = _safe_mtime(path) + # ``getattr``/``setattr`` is intentional — the same implementation + # services both the playbook and drift-card caches, whose Python + # types differ. The cast below restores the precise + # ``(tuple[T, ...], float) | None`` shape for mypy. + cached = cast("tuple[tuple[T, ...], float] | None", getattr(self, cache_attr)) + if cached is not None and cached[1] == mtime: + return cached[0] + items: list[T] = [] + for d in _read_json_array(path): + try: + items.append(decode(d)) + except (TypeError, KeyError, ValueError) as exc: + _LOG.warning("skipping malformed store entry %s: %s", d, exc) + tup = tuple(items) + setattr(self, cache_attr, (tup, mtime)) + return tup + + def _locked_append( + self, + path: Path, + *, + encode_new: Callable[[Any], dict[str, Any]], + new_item: Any, + ) -> None: + self.dir.mkdir(parents=True, exist_ok=True) + lock_path = path.with_suffix(path.suffix + ".lock") + with _locked(lock_path, self.lock_timeout_s): + existing = _read_json_array(path) + existing.append(encode_new(new_item)) + _atomic_write_json(path, existing) + + +def _safe_mtime(path: Path) -> float: + try: + return path.stat().st_mtime + except FileNotFoundError: + return 0.0 + + +def _entry_to_dict(e: PlaybookEntry) -> dict[str, Any]: + return { + "tag_set": sorted(e.tag_set), + "before_snippet": e.before_snippet, + "after_snippet": e.after_snippet, + "avg_speedup": e.avg_speedup, + "scenario_family": e.scenario_family, + "source": e.source, + } + + +def _entry_from_dict(d: dict[str, Any]) -> PlaybookEntry: + source: Literal["preseed", "learned"] = d.get("source", "learned") + return PlaybookEntry( + tag_set=frozenset(d.get("tag_set") or []), + before_snippet=d["before_snippet"], + after_snippet=d["after_snippet"], + avg_speedup=float(d["avg_speedup"]), + scenario_family=d["scenario_family"], + source=source, + ) + + +def cleanup_stale_session_dirs(root: Path, ttl_hours: float) -> int: + """Remove session subdirectories under *root* whose mtime is older than *ttl_hours*. + + Returns the number of directories removed. Errors on individual + subdirectories are logged and skipped so a single bad entry cannot abort + the sweep. Pass ``ttl_hours=0`` to disable (returns 0 immediately). + """ + import shutil + + if ttl_hours <= 0 or not root.exists(): + return 0 + cutoff = time.time() - ttl_hours * 3600 + removed = 0 + for session_dir in root.iterdir(): + if not session_dir.is_dir(): + continue + try: + if session_dir.stat().st_mtime < cutoff: + shutil.rmtree(session_dir, ignore_errors=True) + removed += 1 + except OSError as exc: + _LOG.warning("cleanup_stale_session_dirs: skipping %s: %s", session_dir, exc) + return removed + + +__all__ = ["DEFAULT_LOCK_TIMEOUT_S", "DEFAULT_STORE_DIR", "Store", "cleanup_stale_session_dirs"] diff --git a/sql_drift_grpo_training.ipynb b/sql_drift_grpo_training.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..01d66f453fa9fd544b6e894448c9de0355ea3aff --- /dev/null +++ b/sql_drift_grpo_training.ipynb @@ -0,0 +1,373 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SQLDrift GRPO Training Notebook\n", + "\n", + "Runnable hackathon training artifact for SQLDrift. This notebook uses the repository's existing `training.grpo_train` entrypoint, which wires Unsloth, Hugging Face TRL `GRPOTrainer`, Transformers, PEFT LoRA, and the SQLDrift OpenEnv tool environment.\n", + "\n", + "Assumptions:\n", + "\n", + "- You run this on a GPU runtime, preferably Colab with an NVIDIA T4/A10/L4 or better.\n", + "- `SQL_DRIFT_ENV_URL` points at the deployed SQLDrift HF Space, for example `https://your-space.hf.space`.\n", + "- The notebook is run from the repository root, or `SQL_DRIFT_REPO_URL` is set so Colab can clone the repo.\n", + "\n", + "Success criteria:\n", + "\n", + "1. The deployed environment health check succeeds.\n", + "2. GRPO trains for the configured number of steps against live SQLDrift rollouts.\n", + "3. A LoRA adapter/checkpoint is written under `outputs/checkpoints/grpo_hackathon/`.\n", + "4. Durable evidence files are written under `training/evidence/`: `grpo_metrics.csv`, `grpo_loss_curve.png`, and `grpo_reward_curve.png`.\n", + "\n", + "After pushing this notebook to GitHub, open it in Colab by using:\n", + "`https://colab.research.google.com/github///blob/main/training/sql_drift_grpo_training.ipynb`\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Install The Training Stack\n", + "\n", + "If this notebook is launched directly in Colab, set `SQL_DRIFT_REPO_URL` first. If it is already running from the repository root, the clone step is skipped.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "import os\n", + "import subprocess\n", + "import sys\n", + "\n", + "REPO_URL = os.environ.get(\"SQL_DRIFT_REPO_URL\", \"\").strip()\n", + "repo_root = Path.cwd()\n", + "\n", + "if not (repo_root / \"pyproject.toml\").exists():\n", + " if not REPO_URL:\n", + " raise RuntimeError(\n", + " \"Set SQL_DRIFT_REPO_URL to your GitHub repo URL, or run this notebook \"\n", + " \"from the sql_drift_env repository root.\"\n", + " )\n", + " subprocess.check_call([\"git\", \"clone\", REPO_URL, \"sql_drift_env\"])\n", + " os.chdir(\"sql_drift_env\")\n", + " repo_root = Path.cwd()\n", + "\n", + "subprocess.check_call(\n", + " [\n", + " sys.executable,\n", + " \"-m\",\n", + " \"pip\",\n", + " \"install\",\n", + " \"-q\",\n", + " \"-e\",\n", + " \".[train]\",\n", + " \"matplotlib\",\n", + " \"pandas\",\n", + " \"peft\",\n", + " \"bitsandbytes\",\n", + " ]\n", + ")\n", + "\n", + "print(f\"Repo root: {repo_root}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Point At The Deployed SQLDrift Environment\n", + "\n", + "Set `SQL_DRIFT_ENV_URL` to the HF Space URL before running the cell, for example:\n", + "\n", + "```python\n", + "import os\n", + "os.environ[\"SQL_DRIFT_ENV_URL\"] = \"https://your-space.hf.space\"\n", + "```\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import urllib.request\n", + "\n", + "ENV_BASE_URL = os.environ.get(\"SQL_DRIFT_ENV_URL\", \"\").strip()\n", + "if not ENV_BASE_URL:\n", + " raise RuntimeError(\"Set SQL_DRIFT_ENV_URL to the deployed SQLDrift HF Space URL.\")\n", + "\n", + "health_url = ENV_BASE_URL.rstrip(\"/\") + \"/health\"\n", + "with urllib.request.urlopen(health_url, timeout=30) as response:\n", + " body = response.read().decode(\"utf-8\")\n", + "\n", + "print(f\"Health check OK: {health_url}\")\n", + "print(body[:500])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Configure A Minimal GRPO Run\n", + "\n", + "Defaults are intentionally small enough for a hackathon Colab run while still producing non-trivial curves. Increase `SQL_DRIFT_GRPO_MAX_STEPS` for final evidence.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dataclasses import asdict\n", + "from pathlib import Path\n", + "\n", + "from training.config import ALL_SCENARIOS, CurriculumConfig, GRPOConfig\n", + "\n", + "MAX_STEPS = int(os.environ.get(\"SQL_DRIFT_GRPO_MAX_STEPS\", \"80\"))\n", + "GROUP_SIZE = int(os.environ.get(\"SQL_DRIFT_GRPO_GROUP_SIZE\", \"4\"))\n", + "SEED = int(os.environ.get(\"SQL_DRIFT_GRPO_SEED\", \"7\"))\n", + "OUTPUT_DIR = \"outputs/checkpoints/grpo_hackathon\"\n", + "EVIDENCE_DIR = Path(\"training/evidence\")\n", + "EVIDENCE_DIR.mkdir(parents=True, exist_ok=True)\n", + "\n", + "# P0 audit fixes (design/RL_HACKATHON_AUDIT_AND_RUNBOOK.md \u00a72.6):\n", + "# - max_seq_length 4096 \u2192 2048 (T4 16 GB OOM safety)\n", + "# - max_completion_length 1024 \u2192 256 (JSON tool envelopes need ~200 tok)\n", + "# - max_prompt_length kept large at 1536 (system prompt + tool history)\n", + "# - save_steps max(20, MAX_STEPS // 4) so a step-79 crash still leaves a checkpoint\n", + "# - default model swap to unsloth/Qwen3-1.7B-unsloth-bnb-4bit per audit \u00a74\n", + "cfg = GRPOConfig(\n", + " model_name=os.environ.get(\"SQL_DRIFT_MODEL_NAME\", \"unsloth/Qwen3-1.7B-unsloth-bnb-4bit\"),\n", + " env_base_url=ENV_BASE_URL,\n", + " output_dir=OUTPUT_DIR,\n", + " max_steps=MAX_STEPS,\n", + " group_size=GROUP_SIZE,\n", + " gradient_accumulation_steps=int(os.environ.get(\"SQL_DRIFT_GRAD_ACCUM\", \"1\")),\n", + " learning_rate=float(os.environ.get(\"SQL_DRIFT_LR\", \"5e-6\")),\n", + " warmup_steps=min(10, max(1, MAX_STEPS // 10)),\n", + " save_steps=max(20, MAX_STEPS // 4),\n", + " logging_steps=1,\n", + " seed=SEED,\n", + " max_seq_length=2048,\n", + " max_prompt_length=1536,\n", + " max_completion_length=256,\n", + " curriculum=CurriculumConfig(scenarios=ALL_SCENARIOS, mode=\"weighted\", weights=(1, 1, 1, 1, 1, 1, 2, 2, 2, 2)),\n", + ")\n", + "\n", + "(EVIDENCE_DIR / \"grpo_config.json\").write_text(json.dumps(asdict(cfg), indent=2))\n", + "print(json.dumps(asdict(cfg), indent=2))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Run GRPO Training\n", + "\n", + "This calls `training.grpo_train.train(cfg)`, which loads the Unsloth model, attaches LoRA adapters, builds the curriculum dataset, and lets TRL drive multi-turn OpenEnv rollouts through `SqlDriftToolEnv`.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4a. Free CPU smoke (max_steps=2)\n", + "\n", + "Pre-flight: validate env URL, dataset build, callback registration, and log file plumbing\n", + "**before** spending GPU time. P0 fix per `design/RL_HACKATHON_AUDIT_AND_RUNBOOK.md` \u00a75.3.\n", + "The assertion at the end aborts the notebook if the smoke produced no records." + ] + }, + { + "cell_type": "code", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": [ + "from dataclasses import replace\n", + "\n", + "smoke_cfg = replace(\n", + " cfg,\n", + " max_steps=2,\n", + " save_steps=2,\n", + " output_dir=str(EVIDENCE_DIR / \"smoke\"),\n", + ")\n", + "print(f\"\\n=== SMOKE TEST: {smoke_cfg.max_steps} steps, output={smoke_cfg.output_dir} ===\\n\")\n", + "\n", + "from training.grpo_train import train as _train\n", + "\n", + "_smoke_trainer = _train(smoke_cfg)\n", + "_smoke_history = list(getattr(_smoke_trainer.state, \"log_history\", []))\n", + "assert _smoke_history, \"Smoke produced no log records \u2014 abort before the paid run.\"\n", + "print(f\"Smoke OK \u2014 {len(_smoke_history)} log records. Proceed to the paid run below.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from training.grpo_train import train\n", + "\n", + "trainer = train(cfg)\n", + "history = list(getattr(trainer.state, \"log_history\", []))\n", + "print(f\"Captured {len(history)} trainer log records\")\n", + "print(f\"Checkpoint directory: {OUTPUT_DIR}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Save Evidence Curves\n", + "\n", + "The evaluator requires committed image files, not only WandB links or inline notebook plots. This cell writes PNGs and the raw metrics CSV into `training/evidence/`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "\n", + "df = pd.DataFrame(history)\n", + "if df.empty:\n", + " raise RuntimeError(\"Trainer log history is empty; training did not emit metrics.\")\n", + "\n", + "if \"step\" not in df.columns:\n", + " df[\"step\"] = range(len(df))\n", + "\n", + "metrics_csv = EVIDENCE_DIR / \"grpo_metrics.csv\"\n", + "df.to_csv(metrics_csv, index=False)\n", + "\n", + "\n", + "def _numeric_columns(frame: pd.DataFrame) -> dict[str, pd.Series]:\n", + " out: dict[str, pd.Series] = {}\n", + " for column in frame.columns:\n", + " series = pd.to_numeric(frame[column], errors=\"coerce\")\n", + " if series.notna().any():\n", + " out[column] = series\n", + " return out\n", + "\n", + "\n", + "def _choose_metric(\n", + " frame: pd.DataFrame,\n", + " *,\n", + " preferred: list[str],\n", + " required_token: str,\n", + " exclude_tokens: tuple[str, ...] = (),\n", + ") -> tuple[str, pd.Series]:\n", + " numeric = _numeric_columns(frame)\n", + " for column in preferred:\n", + " if column in numeric:\n", + " return column, numeric[column]\n", + " for column, series in numeric.items():\n", + " lowered = column.lower()\n", + " if required_token in lowered and not any(token in lowered for token in exclude_tokens):\n", + " return column, series\n", + " available = \", \".join(sorted(numeric))\n", + " raise RuntimeError(\n", + " f\"Could not find a numeric {required_token!r} metric. Available: {available}\"\n", + " )\n", + "\n", + "\n", + "def _plot_curve(column: str, values: pd.Series, ylabel: str, title: str, path: Path) -> None:\n", + " plot_df = pd.DataFrame(\n", + " {\n", + " \"step\": pd.to_numeric(df[\"step\"], errors=\"coerce\"),\n", + " \"value\": values,\n", + " }\n", + " ).dropna()\n", + " if plot_df.empty:\n", + " raise RuntimeError(f\"Metric {column!r} has no plottable values.\")\n", + "\n", + " plt.figure(figsize=(8, 4.5))\n", + " plt.plot(plot_df[\"step\"], plot_df[\"value\"], marker=\"o\", linewidth=1.4, label=column)\n", + " if len(plot_df) >= 5:\n", + " window = min(10, max(2, len(plot_df) // 5))\n", + " rolling = plot_df[\"value\"].rolling(window=window, min_periods=1).mean()\n", + " plt.plot(plot_df[\"step\"], rolling, linewidth=2.2, label=f\"{window}-point rolling mean\")\n", + " plt.xlabel(\"GRPO step\")\n", + " plt.ylabel(ylabel)\n", + " plt.title(title)\n", + " plt.grid(alpha=0.25)\n", + " plt.legend()\n", + " plt.tight_layout()\n", + " plt.savefig(path, dpi=180, bbox_inches=\"tight\")\n", + " plt.show()\n", + "\n", + "\n", + "loss_col, loss_values = _choose_metric(\n", + " df,\n", + " preferred=[\"loss\", \"train_loss\"],\n", + " required_token=\"loss\",\n", + ")\n", + "reward_col, reward_values = _choose_metric(\n", + " df,\n", + " preferred=[\"reward\", \"rewards/mean\", \"mean_reward\", \"reward_mean\", \"train/reward\"],\n", + " required_token=\"reward\",\n", + " exclude_tokens=(\"std\", \"variance\"),\n", + ")\n", + "\n", + "loss_png = EVIDENCE_DIR / \"grpo_loss_curve.png\"\n", + "reward_png = EVIDENCE_DIR / \"grpo_reward_curve.png\"\n", + "_plot_curve(loss_col, loss_values, \"loss\", \"SQLDrift GRPO Loss Curve\", loss_png)\n", + "_plot_curve(reward_col, reward_values, \"episode reward\", \"SQLDrift GRPO Reward Curve\", reward_png)\n", + "\n", + "print(f\"Wrote metrics: {metrics_csv}\")\n", + "print(f\"Wrote loss curve: {loss_png}\")\n", + "print(f\"Wrote reward curve: {reward_png}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Commit The Evidence\n", + "\n", + "Once training completes, commit the notebook plus the generated evidence images:\n", + "\n", + "```bash\n", + "git add training/sql_drift_grpo_training.ipynb training/evidence/grpo_loss_curve.png training/evidence/grpo_reward_curve.png training/evidence/grpo_metrics.csv training/evidence/grpo_config.json\n", + "git commit -m \"Add SQLDrift GRPO training notebook and evidence\"\n", + "```\n", + "\n", + "Do not rely on WandB-only links or plots that live only in notebook output cells; the PNG files above are the durable artifacts evaluators can validate.\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/training/__init__.py b/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7884ff7bf85671733e29e813685a5a9cfa5f8b01 --- /dev/null +++ b/training/__init__.py @@ -0,0 +1,27 @@ +"""TRL/Unsloth GRPO training harness — P12/P13. + +Public surface: + +* :class:`training.config.GRPOConfig` / :class:`training.config.CurriculumConfig` +* :func:`training.prompt.render_system_prompt` +* :class:`training.random_agent.RandomAgent` +* :func:`training.grpo_train.train` (requires GPU + ``[train]`` extra) +""" + +from __future__ import annotations + +from training.config import ALL_SCENARIOS, CurriculumConfig, GRPOConfig +from training.prompt import ( + render_prompt_from_observation, + render_system_prompt, +) +from training.random_agent import RandomAgent + +__all__ = [ + "ALL_SCENARIOS", + "CurriculumConfig", + "GRPOConfig", + "RandomAgent", + "render_prompt_from_observation", + "render_system_prompt", +] diff --git a/training/config.py b/training/config.py new file mode 100644 index 0000000000000000000000000000000000000000..5455ee349ff64553030c60449347ce6ecdd97453 --- /dev/null +++ b/training/config.py @@ -0,0 +1,127 @@ +"""Training configuration dataclasses. + +Holds every knob the :mod:`training.grpo_train` script or the eval CLI +needs, as plain, frozen dataclasses so they serialize cleanly to JSON +for experiment manifests. + +Deliberately lightweight: do not import ``trl`` / ``unsloth`` / +``transformers`` at module import time. Those libraries are CUDA-heavy +and optional. ``grpo_train.py`` resolves them lazily. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Literal + +from utilities.env_loader import env_str + + +def _load_all_scenarios() -> tuple[str, ...]: + from scenarios import iter_specs + + return tuple(spec.scenario_id for spec in iter_specs()) + + +# Derived from the live registry so training defaults stay in sync with +# the scenarios actually shipped under ``scenarios/``. +ALL_SCENARIOS: tuple[str, ...] = _load_all_scenarios() + + +@dataclass(frozen=True) +class CurriculumConfig: + """Scenario sampling policy for GRPO rollouts. + + ``mode="uniform"`` samples each id in :attr:`scenarios` with equal + probability. ``mode="weighted"`` uses :attr:`weights` (must be the + same length as :attr:`scenarios`) — useful for over-sampling drift + scenarios early in training. ``mode="static_order"`` iterates the + list round-robin (handy for reproducing eval-style runs). + """ + + scenarios: tuple[str, ...] = ALL_SCENARIOS + mode: Literal["uniform", "weighted", "static_order"] = "uniform" + weights: tuple[float, ...] | None = None + seed_range: tuple[int, int] = (0, 2**31 - 1) + + def __post_init__(self) -> None: + if not self.scenarios: + raise ValueError("CurriculumConfig.scenarios must be non-empty") + if self.mode == "weighted": + if self.weights is None or len(self.weights) != len(self.scenarios): + raise ValueError("mode='weighted' requires weights of the same length as scenarios") + if any(w < 0 for w in self.weights): + raise ValueError("weights must all be >= 0") + if sum(self.weights) <= 0: + raise ValueError("at least one weight must be > 0") + lo, hi = self.seed_range + if lo < 0 or hi <= lo: + raise ValueError("seed_range must be (lo >= 0, hi > lo)") + + +@dataclass(frozen=True) +class GRPOConfig: + """Top-level training config for the GRPO skeleton. + + Defaults: Qwen3-1.7B (Unsloth bnb-4bit) + LoRA r=16, group size 4, + 500 max steps — but every knob is override-able from the CLI or a + JSON/YAML manifest. Model swap from Qwen2.5-1.5B → Qwen3-1.7B per + design/RL_HACKATHON_AUDIT_AND_RUNBOOK.md (section 4: newer family with + documented agentic tool-calling, comparable footprint). + """ + + model_name: str = "unsloth/Qwen3-1.7B-unsloth-bnb-4bit" + max_seq_length: int = 4096 + load_in_4bit: bool = True + + lora_r: int = 16 + lora_alpha: int = 16 + lora_dropout: float = 0.0 + lora_target_modules: tuple[str, ...] = ( + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ) + + # GRPO knobs + group_size: int = 4 + learning_rate: float = 5e-6 + max_steps: int = 500 + gradient_accumulation_steps: int = 1 + warmup_steps: int = 10 + max_prompt_length: int = 2048 + max_completion_length: int = 1024 + temperature: float = 0.8 + top_p: float = 0.95 + seed: int = 0 + + # Env wiring + env_base_url: str = env_str("SQL_DRIFT_ENV_URL", "http://localhost:8000") + episode_step_budget: int = 25 + dba_oracle_enabled: bool = False + + # IO + output_dir: str = "outputs/grpo_run" + logging_steps: int = 1 + save_steps: int = 100 + + curriculum: CurriculumConfig = field(default_factory=CurriculumConfig) + + def __post_init__(self) -> None: + if self.group_size < 2: + raise ValueError("GRPO group_size must be >= 2 (GRPO requires groups).") + if self.max_steps < 1: + raise ValueError("max_steps must be >= 1") + if self.seed < 0: + raise ValueError("seed must be >= 0") + if self.lora_r < 1: + raise ValueError("lora_r must be >= 1") + if not 0.0 < self.temperature <= 2.0: + raise ValueError("temperature must be in (0, 2]") + + +__all__ = ["ALL_SCENARIOS", "CurriculumConfig", "GRPOConfig"] diff --git a/training/eval.py b/training/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..2e867930c4348b0b1255534b150644ff58856ae2 --- /dev/null +++ b/training/eval.py @@ -0,0 +1,519 @@ +"""Evaluation harness for SQLDrift. + +K-rollouts-per-scenario sweep that emits a markdown ``report.md`` and a +row-level ``per_episode.csv`` so reviewers can eyeball pass rate and +reward distribution without spinning up a notebook. + +Invocation:: + + python -m training.eval \ + --checkpoint base \ + --scenarios 1-10 \ + --seeds-per-scenario 5 \ + --out outputs/evals// + +``--checkpoint base`` runs the CPU :class:`RandomAgent` as a baseline +(no LLM, no GPU). A non-"base" value is a pointer to a saved adapter +and will attempt the lazy Unsloth import path in :func:`load_agent`. +""" + +from __future__ import annotations + +import argparse +import csv +import json +import math +import statistics +import time +from collections.abc import Callable +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Protocol + +from models import REWARD_COMPONENT_KEYS, SqlDriftAction, SqlDriftObservation +from server import SqlDriftEnvironment +from training.config import ALL_SCENARIOS +from training.random_agent import RandomAgent +from training.seeding import set_seed +from utilities.logger import get_module_logger + +_LOG = get_module_logger(__name__) + +PASS_REWARD_THRESHOLD = 0.5 +"""An episode is a "pass" when its terminal reward (the step that flips +``obs.done``—the submit step or the last step before budget exhaustion) +meets this threshold. The rubric scores a correct submission at +``+1.0`` in :attr:`r_correct`, so ``0.5`` matches the +"correct-but-not-improved" knee. Per-step values accumulate in +``episode_return`` and are not used for the pass bit.""" + + +# ----------------------------------------------------------------------------- +# Agent interface +# ----------------------------------------------------------------------------- + + +class Agent(Protocol): + """Duck-typed policy — :meth:`RandomAgent.act` fits this shape.""" + + def reset(self, seed: int | None = None, scenario_id: str | None = None) -> None: ... + def act(self, obs: SqlDriftObservation) -> SqlDriftAction: ... + + +def load_agent( + checkpoint: str, + *, + seed: int = 0, + base_model: str | None = None, + temperature: float = 0.0, +) -> Agent: + """Resolve a checkpoint spec to a concrete agent. + + * ``base``/``random`` — CPU-only :class:`RandomAgent` baseline. + * Any other value is treated as a filesystem path (a full + Hugging Face checkpoint directory or a PEFT adapter directory). The + :class:`training.llm_agent.LLMAgent` is imported lazily so + CPU-only CI that never calls ``load_agent`` with a path never + has to install ``transformers``/``peft``. + + ``base_model`` is forwarded to :class:`LLMAgent` when the adapter + directory does not pin its base model; ``temperature=0`` (greedy) + is the default for deterministic eval sweeps. + """ + if checkpoint in ("base", "random"): + return RandomAgent(seed=seed) + + path = Path(checkpoint) + if not path.exists(): + raise FileNotFoundError( + f"checkpoint {checkpoint!r} is not 'base'/'random' and does not exist on disk" + ) + + # Lazy import: keeps the heavy transformers/peft tree out of + # the default import surface for ``training.eval``. + from training.llm_agent import LLMAgent + + return LLMAgent( + str(path), + base_model=base_model, + temperature=temperature, + seed=seed, + ) + + +# ----------------------------------------------------------------------------- +# Episode runner +# ----------------------------------------------------------------------------- + + +@dataclass +class EpisodeResult: + scenario_id: str + seed: int + terminal_reward: float + episode_return: float + steps: int + passed: bool + submitted: bool + drift_fired: bool + wall_ms: float + reward_components: dict[str, float] = field(default_factory=dict) + effective_speedup: float | None = None + + def as_row(self) -> dict[str, str]: + row: dict[str, str] = { + "scenario_id": self.scenario_id, + "seed": str(self.seed), + "terminal_reward": f"{self.terminal_reward:.4f}", + "episode_return": f"{self.episode_return:.4f}", + "steps": str(self.steps), + "passed": "1" if self.passed else "0", + "submitted": "1" if self.submitted else "0", + "drift_fired": "1" if self.drift_fired else "0", + "wall_ms": f"{self.wall_ms:.2f}", + "effective_speedup": _format_speedup(self.effective_speedup), + } + for k in REWARD_COMPONENT_KEYS: + row[k] = f"{self.reward_components.get(k, 0.0):.4f}" + return row + + +def _format_speedup(value: float | None) -> str: + """Render an effective_speedup cell for CSV output. + + ``None`` (no submission) and ``+∞`` (baseline invalidated by drift) + need distinct, non-numeric representations so a downstream parser + cannot conflate "no data" with "infinite" — both read back as empty + cells today, which would bias per-scenario means. + """ + if value is None: + return "" + if math.isinf(value): + return "inf" + return f"{value:.3f}" + + +def _effective_speedup(env: SqlDriftEnvironment) -> float | None: + """Read the current episode's effective speedup through the env surface. + + Kept as a thin shim so tests and reporting code don't need to reach + into env internals themselves. + """ + return env.effective_speedup() + + +def _run_one_episode( + env: SqlDriftEnvironment, + agent: Agent, + *, + scenario_id: str, + seed: int, + max_steps: int = 30, +) -> EpisodeResult: + t0 = time.perf_counter() + obs = env.reset(seed=seed, scenario_id=scenario_id) + agent.reset(seed=seed, scenario_id=scenario_id) + + episode_return = 0.0 + terminal_reward = 0.0 + last_components: dict[str, float] = {} + steps = 0 + while not obs.done and steps < max_steps: + action = agent.act(obs) + obs = env.step(action) + if obs.reward is not None: + episode_return += obs.reward + terminal_reward = obs.reward + if obs.reward_components: + last_components = dict(obs.reward_components) + steps += 1 + + state = env.state + wall_ms = (time.perf_counter() - t0) * 1000.0 + + return EpisodeResult( + scenario_id=scenario_id, + seed=seed, + terminal_reward=terminal_reward, + episode_return=episode_return, + steps=steps, + passed=terminal_reward >= PASS_REWARD_THRESHOLD, + submitted=state.submitted, + drift_fired=state.drift_fired, + wall_ms=wall_ms, + reward_components=last_components, + effective_speedup=_effective_speedup(env), + ) + + +def run_eval( + *, + checkpoint: str, + scenarios: list[str], + seeds_per_scenario: int, + out_dir: Path, + max_steps: int = 30, + base_seed: int = 0, + progress_cb: Callable[[int, int], None] | None = None, + base_model: str | None = None, + temperature: float = 0.0, +) -> dict[str, Any]: + """Drive the full sweep. Returns the summary dict also written to JSON. + + ``base_model`` / ``temperature`` are forwarded to + :func:`load_agent` for LLM-checkpoint runs (ignored for + ``base``/``random``). + """ + out_dir.mkdir(parents=True, exist_ok=True) + set_seed(base_seed) + agent = load_agent( + checkpoint, + seed=base_seed, + base_model=base_model, + temperature=temperature, + ) + + results: list[EpisodeResult] = [] + env = SqlDriftEnvironment() + total = len(scenarios) * seeds_per_scenario + done = 0 + try: + for scenario_id in scenarios: + for k in range(seeds_per_scenario): + seed = base_seed + k + res = _run_one_episode( + env, + agent, + scenario_id=scenario_id, + seed=seed, + max_steps=max_steps, + ) + results.append(res) + done += 1 + if progress_cb: + progress_cb(done, total) + finally: + env.close() + + _write_per_episode_csv(results, out_dir / "per_episode.csv") + summary = _build_summary( + results, + checkpoint=checkpoint, + scenarios=scenarios, + seeds_per_scenario=seeds_per_scenario, + ) + (out_dir / "summary.json").write_text(json.dumps(summary, indent=2)) + (out_dir / "report.md").write_text(render_report(summary, results)) + return summary + + +# ----------------------------------------------------------------------------- +# Reporting +# ----------------------------------------------------------------------------- + + +def _write_per_episode_csv(results: list[EpisodeResult], path: Path) -> None: + if not results: + path.write_text("") + return + fieldnames = list(results[0].as_row().keys()) + with path.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for r in results: + writer.writerow(r.as_row()) + + +def _speedup_stats( + speedups: list[float], +) -> tuple[float | None, int]: + """Finite-mean + ``+∞`` count for a list of effective_speedup values. + + ``statistics.fmean`` on a list containing ``+∞`` returns ``+∞``, + which poisons JSON output (``json.dumps`` rejects ``Infinity`` by + default) and washes out the finite distribution. We split the two + channels: the reported mean is over *finite* speedups only, and the + ``infinite_count`` is surfaced separately so reviewers can see how + often drift invalidated the baseline. + """ + finite = [s for s in speedups if not math.isinf(s)] + infinite_count = len(speedups) - len(finite) + mean = statistics.fmean(finite) if finite else None + return mean, infinite_count + + +def _build_summary( + results: list[EpisodeResult], + *, + checkpoint: str, + scenarios: list[str], + seeds_per_scenario: int, +) -> dict[str, Any]: + by_scenario: dict[str, dict[str, Any]] = {} + for sid in scenarios: + scoped = [r for r in results if r.scenario_id == sid] + if not scoped: + continue + terminals = [r.terminal_reward for r in scoped] + returns = [r.episode_return for r in scoped] + speedups = [r.effective_speedup for r in scoped if r.effective_speedup is not None] + mean_sp, inf_sp = _speedup_stats(speedups) + by_scenario[sid] = { + "n": len(scoped), + "pass_rate": sum(1 for r in scoped if r.passed) / len(scoped), + "mean_terminal_reward": statistics.fmean(terminals), + "std_terminal_reward": statistics.pstdev(terminals) if len(terminals) > 1 else 0.0, + "mean_episode_return": statistics.fmean(returns), + "submit_rate": sum(1 for r in scoped if r.submitted) / len(scoped), + "mean_effective_speedup": mean_sp, + "infinite_speedup_count": inf_sp, + } + + all_speedups = [r.effective_speedup for r in results if r.effective_speedup is not None] + mean_sp_all, inf_sp_all = _speedup_stats(all_speedups) + overall = { + "checkpoint": checkpoint, + "n_episodes": len(results), + "seeds_per_scenario": seeds_per_scenario, + "pass_rate": sum(1 for r in results if r.passed) / len(results) if results else 0.0, + "mean_terminal_reward": ( + statistics.fmean(r.terminal_reward for r in results) if results else 0.0 + ), + "std_terminal_reward": ( + statistics.pstdev([r.terminal_reward for r in results]) if len(results) > 1 else 0.0 + ), + "mean_episode_return": ( + statistics.fmean(r.episode_return for r in results) if results else 0.0 + ), + "submit_rate": sum(1 for r in results if r.submitted) / len(results) if results else 0.0, + "mean_effective_speedup": mean_sp_all, + "infinite_speedup_count": inf_sp_all, + } + return {"overall": overall, "by_scenario": by_scenario} + + +def render_report( + summary: dict[str, Any], + results: list[EpisodeResult], +) -> str: + """Compose a reviewer-friendly ``report.md`` string.""" + overall = summary["overall"] + lines: list[str] = [] + lines.append("# SQLDrift evaluation report") + lines.append("") + speedup = overall.get("mean_effective_speedup") + speedup_cell = f"{speedup:.2f}x" if speedup is not None else "—" + lines.append(f"- Checkpoint: `{overall['checkpoint']}`") + lines.append(f"- Episodes: **{overall['n_episodes']}**") + lines.append(f"- Seeds/scenario: {overall['seeds_per_scenario']}") + lines.append( + f"- Overall pass rate (terminal reward ≥ {PASS_REWARD_THRESHOLD}): " + f"**{overall['pass_rate']:.1%}**" + ) + lines.append( + f"- Mean terminal reward: **{overall['mean_terminal_reward']:.3f}** " + f"(σ = {overall['std_terminal_reward']:.3f})" + ) + lines.append(f"- Mean episode return: {overall['mean_episode_return']:.3f}") + lines.append(f"- Submit rate: {overall['submit_rate']:.1%}") + lines.append(f"- Mean effective speedup (finite, submitted episodes): **{speedup_cell}**") + inf_count = overall.get("infinite_speedup_count", 0) + if inf_count: + lines.append(f"- Infinite-speedup episodes (drift invalidated baseline): **{inf_count}**") + lines.append("") + + lines.append("## Per-scenario") + lines.append("") + lines.append("| Scenario | N | Pass | Terminal μ | Return μ | Submit | Speedup |") + lines.append("|----------|---|------|-----------|----------|--------|---------|") + for sid, row in summary["by_scenario"].items(): + sp = row.get("mean_effective_speedup") + sp_cell = f"{sp:.2f}x" if sp is not None else "—" + lines.append( + f"| `{sid}` | {row['n']} | {row['pass_rate']:.0%} | " + f"{row['mean_terminal_reward']:.3f} | {row['mean_episode_return']:.3f} | " + f"{row['submit_rate']:.0%} | {sp_cell} |" + ) + lines.append("") + + lines.append("## Reward-component bars (mean across episodes)") + lines.append("") + lines.append(_render_component_bars(results)) + lines.append("") + + lines.append("## Notes") + lines.append("") + lines.append( + "- Random baseline does NOT attempt rewrites intelligently; " + "nonzero pass rate here is a lower bound on a well-trained agent." + ) + lines.append( + "- Pass threshold is set at reward ≥ " + f"{PASS_REWARD_THRESHOLD}, matching the rubric's " + '"correct-but-not-improved" +0.5 partial credit (rubric).' + ) + return "\n".join(lines) + "\n" + + +def _render_component_bars(results: list[EpisodeResult]) -> str: + if not results: + return "_no data_" + sums: dict[str, float] = {k: 0.0 for k in REWARD_COMPONENT_KEYS} + counts: dict[str, int] = {k: 0 for k in REWARD_COMPONENT_KEYS} + for r in results: + for k, v in r.reward_components.items(): + if k in sums: + sums[k] += v + counts[k] += 1 + means = {k: (sums[k] / counts[k] if counts[k] else 0.0) for k in sums} + + # ASCII bar with sign (+/-). + max_abs = max((abs(v) for v in means.values()), default=1.0) or 1.0 + width = 30 + lines = ["```"] + for k in REWARD_COMPONENT_KEYS: + v = means[k] + bar_len = int(round(abs(v) / max_abs * width)) + bar = ("█" * bar_len) if v >= 0 else ("▒" * bar_len) + lines.append(f"{k:<14} {v:+7.3f} {bar}") + lines.append("```") + return "\n".join(lines) + + +# ----------------------------------------------------------------------------- +# CLI +# ----------------------------------------------------------------------------- + + +def _expand_scenarios(spec: str) -> list[str]: + """Accept either ``1-10``, ``1,3,5``, or a comma-list of raw ids.""" + if "-" in spec and all(part.isdigit() for part in spec.split("-")): + lo, hi = (int(x) for x in spec.split("-")) + want_range: set[int] = set(range(lo, hi + 1)) + return [s for s in ALL_SCENARIOS if int(s.split("_", 1)[0]) in want_range] + if all(part.strip().isdigit() for part in spec.split(",")): + want_set = {int(part) for part in spec.split(",")} + return [s for s in ALL_SCENARIOS if int(s.split("_", 1)[0]) in want_set] + return [s.strip() for s in spec.split(",")] + + +def main(argv: list[str] | None = None) -> None: + ap = argparse.ArgumentParser(description="Evaluate an agent on SQLDrift.") + ap.add_argument("--checkpoint", required=True, help="'base' or adapter path") + ap.add_argument( + "--scenarios", + default="1-10", + help="Scenario range/list (e.g. '1-10', '1,3,5') or raw ids", + ) + ap.add_argument("--seeds-per-scenario", type=int, default=5) + ap.add_argument("--out", type=Path, required=True) + ap.add_argument("--max-steps", type=int, default=30) + ap.add_argument("--base-seed", type=int, default=0) + ap.add_argument( + "--base-model", + default=None, + help="Override the base model id when --checkpoint points at a PEFT adapter", + ) + ap.add_argument( + "--temperature", + type=float, + default=0.0, + help="Decoding temperature for LLM checkpoints (0 = greedy / deterministic).", + ) + ns = ap.parse_args(argv) + + scenarios = _expand_scenarios(ns.scenarios) + if not scenarios: + raise SystemExit(f"no scenarios matched spec {ns.scenarios!r}") + + def _prog(done: int, total: int) -> None: + if done == total or done % max(1, total // 10) == 0: + _LOG.info("eval: %d/%d episodes", done, total) + + summary = run_eval( + checkpoint=ns.checkpoint, + scenarios=scenarios, + seeds_per_scenario=ns.seeds_per_scenario, + out_dir=ns.out, + max_steps=ns.max_steps, + base_seed=ns.base_seed, + progress_cb=_prog, + base_model=ns.base_model, + temperature=ns.temperature, + ) + print(json.dumps(summary["overall"], indent=2)) + + +if __name__ == "__main__": + main() + + +__all__ = [ + "Agent", + "EpisodeResult", + "PASS_REWARD_THRESHOLD", + "load_agent", + "main", + "render_report", + "run_eval", +] diff --git a/training/evidence/grpo_components_combined.png b/training/evidence/grpo_components_combined.png new file mode 100644 index 0000000000000000000000000000000000000000..1fa3168dd1b38dafa09b17ef570b68eb3da631ab --- /dev/null +++ b/training/evidence/grpo_components_combined.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:48462be701ddc14eb7fd4933a9ad1c94c9b5a3692169e9d98dc56ab12a78baaa +size 231592 diff --git a/training/evidence/grpo_loss_curve.png b/training/evidence/grpo_loss_curve.png new file mode 100644 index 0000000000000000000000000000000000000000..8d46cbcf44c0f79d62c131edb4101e5981cc449c Binary files /dev/null and b/training/evidence/grpo_loss_curve.png differ diff --git a/training/evidence/grpo_metrics.csv b/training/evidence/grpo_metrics.csv new file mode 100644 index 0000000000000000000000000000000000000000..c44070a7256482ea14ff35000f6fd83aa5d0d93f --- /dev/null +++ b/training/evidence/grpo_metrics.csv @@ -0,0 +1,21 @@ +step,loss,reward,r_correct,r_drift,r_step_tax +1,1.5070857702340332,-1.089657810470115,0.0060142775960789485,-0.05147482162696701,-0.5153269778837614 +2,1.4163939988607464,-0.8820803512725144,0.03844503923645059,-0.10545186822296794,-0.4861222451750179 +3,1.4296624851767596,-0.9146981630547505,0.10331660081021651,0.07228358802913387,-0.4961777953169654 +4,1.3142921668473637,-0.5595880667701431,0.08344587035631848,0.059823292014677544,-0.520072970343028 +5,1.2701518625957149,-0.5996319117518063,0.2610718206538336,0.004943992480230808,-0.4840739746024923 +6,1.2640116521205307,-0.8258103297724467,0.26761100025304996,0.1913788646035109,-0.49710482987102533 +7,1.1209233460223358,-0.6211756254418923,0.27404383444577135,0.200216496595977,-0.4682072435811402 +8,1.1774670143175638,-0.619087388960126,0.270469688082585,0.17365149920126927,-0.4864816492894405 +9,1.1983245069119268,-0.5491623820497955,0.3851815028189252,0.14236176514396473,-0.47340387335667083 +10,1.1900131680220651,-0.4531450217747221,0.46035015896826725,0.20935614488999046,-0.3977678494548791 +11,1.0778750024182207,-0.5529804804536397,0.4784254548477095,0.25126495579288227,-0.4914229994698512 +12,1.0142047885469356,-0.30198009788159785,0.5200582520058386,0.24302822958435058,-0.46425248430094396 +13,1.0045646774653714,-0.4645827945838715,0.5862339331205229,0.27468662267016536,-0.4534963800694839 +14,0.9116167040003166,-0.3810888813065016,0.5319641343757862,0.2831807948097964,-0.43758017046734105 +15,0.9119484849618256,-0.19031109272031552,0.544487836273619,0.30242310302170133,-0.43358305018602905 +16,0.8631872796395779,-0.1902270628270999,0.5378009235565281,0.46520136753468744,-0.4378038803670464 +17,0.9108494430815297,-0.28522550174067374,0.6317914174443212,0.39742825226075135,-0.4194535022660978 +18,0.7414210838734385,-0.026510480876009573,0.6298982943138216,0.49275288649025134,-0.4263594599419731 +19,0.8025799827262751,-0.09087146161792733,0.69004717240886,0.4924557349195555,-0.3603965798497102 +20,0.6973776508874913,0.011148493528077183,0.7800312956412032,0.456173588246084,-0.4150265290706628 diff --git a/training/evidence/grpo_r_correct_curve.png b/training/evidence/grpo_r_correct_curve.png new file mode 100644 index 0000000000000000000000000000000000000000..25bce38b6ff31ea443b7efe61e3165e480cb29c1 Binary files /dev/null and b/training/evidence/grpo_r_correct_curve.png differ diff --git a/training/evidence/grpo_r_drift_curve.png b/training/evidence/grpo_r_drift_curve.png new file mode 100644 index 0000000000000000000000000000000000000000..7b845ec63f6f0e8e198d04fbf95b2faadebe651f Binary files /dev/null and b/training/evidence/grpo_r_drift_curve.png differ diff --git a/training/evidence/grpo_r_step_tax_curve.png b/training/evidence/grpo_r_step_tax_curve.png new file mode 100644 index 0000000000000000000000000000000000000000..70a58fbfdb6b1a7d9bc9c6f066ed9ff7a34bab42 Binary files /dev/null and b/training/evidence/grpo_r_step_tax_curve.png differ diff --git a/training/evidence/grpo_reward_curve.png b/training/evidence/grpo_reward_curve.png new file mode 100644 index 0000000000000000000000000000000000000000..0429501f774d801d29429f9dbf9e237b73fdf92b --- /dev/null +++ b/training/evidence/grpo_reward_curve.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e043085089726b5d6b7b648c742d6102a5c3c6ee25f5c97eb2eeb1e894a225b +size 106759 diff --git a/training/grpo_train.py b/training/grpo_train.py new file mode 100644 index 0000000000000000000000000000000000000000..6adfed318e9fca71b248e2d37452b3696f8c07b3 --- /dev/null +++ b/training/grpo_train.py @@ -0,0 +1,329 @@ +"""GRPO training entrypoint wiring for SQLDrift. + +This module hosts the building blocks for training a Qwen2.5-class +model against the SQLDrift OpenEnv environment with +:class:`trl.GRPOTrainer`: + +* :func:`iter_curriculum` — pure, dependency-free scenario sampler + used both by :func:`build_dataset` and by unit tests. +* :func:`build_dataset` — turns the curriculum iterator into a + Hugging Face :class:`datasets.Dataset` whose rows carry the per- + episode seed / scenario id that :class:`.tool_env.SqlDriftToolEnv` + consumes in its ``reset(**kwargs)``. +* :func:`load_model_and_tokenizer` — lazy Unsloth loader that keeps + the CUDA dependency tree out of the module top-level so CPU-only + CI can still import this file. +* :func:`build_env_client` — sanctioned way to obtain the OpenEnv + client bound to a running SQLDrift server (used by the eval + harness and notebooks — the trainer builds its own clients via + ``environment_factory``). +* :func:`reward_from_environments` — a TRL-compatible reward + function that reads the cumulative trajectory return off each + rolled-out :class:`SqlDriftToolEnv` instance. +* :func:`train` — the real entrypoint: builds a curriculum dataset, + loads the model, instantiates :class:`trl.GRPOTrainer` with + an ``environment_factory`` bound to :class:`SqlDriftToolEnv` and the + caller-supplied env URL (the TRL-sanctioned multi-turn OpenEnv + rollout path, see the TRL OpenEnv integration guide) and runs + ``trainer.train()``. All heavy imports are lazy so CPU-only CI can + still import the module. +""" + +from __future__ import annotations + +import json +from collections.abc import Iterator, Sequence +from dataclasses import asdict +from functools import partial +from pathlib import Path +from random import Random +from typing import Any + +from training.config import GRPOConfig +from training.prompt import render_system_prompt +from training.seeding import set_seed +from training.tool_env import SqlDriftToolEnv +from utilities.logger import get_module_logger + +_LOG = get_module_logger(__name__) + + +def iter_curriculum(config: GRPOConfig, *, seed: int = 0) -> Iterator[tuple[str, int]]: + """Yield an infinite stream of ``(scenario_id, episode_seed)`` tuples.""" + rng = Random(seed) + curr = config.curriculum + lo, hi = curr.seed_range + i = 0 + while True: + if curr.mode == "uniform": + scenario = rng.choice(curr.scenarios) + elif curr.mode == "weighted": + scenario = rng.choices(curr.scenarios, weights=list(curr.weights or ()), k=1)[0] + else: + scenario = curr.scenarios[i % len(curr.scenarios)] + yield scenario, rng.randint(lo, hi - 1) + i += 1 + + +def build_dataset(config: GRPOConfig, *, num_rows: int, seed: int = 0) -> Any: + """Build a :class:`datasets.Dataset` of prompt rows for GRPO. + + Every row pre-computes the system prompt for a single ``(scenario, + seed)`` pair so the trainer sees a normal chat-format ``prompt`` + column. The extra columns (``scenario_id``, ``seed``, + ``budget_steps``, ``enable_dba_oracle``) ride along verbatim and are forwarded by TRL as + ``**kwargs`` to :meth:`SqlDriftToolEnv.reset`, which is how we + reproducibly pin each episode to its curriculum slot. + + This function imports :mod:`datasets` lazily to keep the stdlib- + only import surface for CPU-only CI. + """ + if num_rows < 1: + raise ValueError("build_dataset requires num_rows >= 1") + + from datasets import Dataset + + it = iter_curriculum(config, seed=seed) + prompts: list[list[dict[str, str]]] = [] + scenario_ids: list[str] = [] + seeds: list[int] = [] + budgets: list[int] = [] + oracle_flags: list[bool] = [] + + for _ in range(num_rows): + scenario_id, episode_seed = next(it) + system = render_system_prompt( + scenario_id=scenario_id, dba_enabled=config.dba_oracle_enabled + ) + prompts.append( + [ + {"role": "system", "content": system}, + { + "role": "user", + "content": ( + "Diagnose, adapt to any drift, and submit a correct " + "rewrite of the baseline query using the tools " + "provided. Call submit_rewrite when confident." + ), + }, + ] + ) + scenario_ids.append(scenario_id) + seeds.append(episode_seed) + budgets.append(config.episode_step_budget) + oracle_flags.append(config.dba_oracle_enabled) + + return Dataset.from_dict( + { + "prompt": prompts, + "scenario_id": scenario_ids, + "seed": seeds, + "budget_steps": budgets, + "enable_dba_oracle": oracle_flags, + } + ) + + +def load_model_and_tokenizer(config: GRPOConfig) -> tuple[Any, Any]: + """Import Unsloth + attach LoRA. Requires the ``[train]`` extra on GPU.""" + from unsloth import FastLanguageModel + + model, tokenizer = FastLanguageModel.from_pretrained( + model_name=config.model_name, + max_seq_length=config.max_seq_length, + load_in_4bit=config.load_in_4bit, + ) + model = FastLanguageModel.get_peft_model( + model, + r=config.lora_r, + target_modules=list(config.lora_target_modules), + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + use_gradient_checkpointing="unsloth", + ) + return model, tokenizer + + +def build_env_client(config: GRPOConfig) -> Any: + """Instantiate the OpenEnv client bound to the running SQLDrift server. + + The GRPO trainer itself does not use this helper — it owns its own + client lifecycle via ``environment_factory=SqlDriftToolEnv``. The + helper exists for the eval harness and notebooks that still want a + one-shot client. + """ + from client import SqlDriftEnv + + return SqlDriftEnv(base_url=config.env_base_url).sync() + + +def reward_from_environments( + environments: Sequence[SqlDriftToolEnv], + **_: Any, +) -> list[float]: + """TRL-compatible reward function. + + SQLDrift's reward shaping is trajectory-based: step tax, tool-error + penalties, repeat-failing-query penalties, and DBA consult penalties + accrue *before* the final submit step. GRPO therefore needs the + running :attr:`SqlDriftToolEnv.episode_return`, not just the last + step's reward, or it would silently discard most of the shaping + signal during training. + """ + return [float(env.episode_return) for env in environments] + + +def _build_flush_log_history_callback(out_path: Path) -> Any: + """Construct a TrainerCallback that appends each log dict to JSONL. + + Why: `trainer.state.log_history` is in-memory only; a crash mid-run + wipes the curves needed by `utilities/plot_curves.py`. This callback + flushes per `on_log` so a step-N crash still leaves N records on + disk for the post-mortem plot. + + The callback class is constructed lazily inside `train()` because + `transformers.TrainerCallback` is a [train]-extra import. + """ + from transformers import TrainerCallback + + class _FlushLogHistory(TrainerCallback): + def __init__(self, target: Path) -> None: + self._target = target + self._target.parent.mkdir(parents=True, exist_ok=True) + + def on_log( + self, + args: Any, + state: Any, + control: Any, + logs: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + if not logs: + return + record = {"step": state.global_step, **logs} + with self._target.open("a") as f: + f.write(json.dumps(record, default=str) + "\n") + + return _FlushLogHistory(out_path) + + +def train(config: GRPOConfig) -> Any: + """Run the full GRPO training loop against SQLDrift. + + All heavy imports (``trl``, ``datasets``, ``transformers``, + ``unsloth``) are deferred into this function body. A CPU-only + checkout that only imports :mod:`training.grpo_train` for, say, + :func:`iter_curriculum` never triggers them. + + The function is covered by the ``[train]`` extra (see + ``pyproject.toml``) and is the path exercised in the Colab + notebook bundled with the hackathon submission. + """ + set_seed(config.seed) + out = Path(config.output_dir) + out.mkdir(parents=True, exist_ok=True) + (out / "grpo_config.json").write_text( + json.dumps(asdict(config), default=str, indent=2), + ) + + from trl import GRPOConfig as TRLGRPOConfig + from trl import GRPOTrainer + + dataset = build_dataset( + config, + num_rows=max(config.max_steps * config.group_size, config.group_size), + seed=config.seed, + ) + + trl_args = TRLGRPOConfig( + output_dir=str(out), + learning_rate=config.learning_rate, + max_steps=config.max_steps, + per_device_train_batch_size=config.group_size, + num_generations=config.group_size, + gradient_accumulation_steps=config.gradient_accumulation_steps, + warmup_steps=config.warmup_steps, + max_prompt_length=config.max_prompt_length, + max_completion_length=config.max_completion_length, + temperature=config.temperature, + top_p=config.top_p, + logging_steps=config.logging_steps, + save_steps=config.save_steps, + log_completions=True, + report_to=["tensorboard"], + logging_dir=str(out / "tb"), + ) + + model, tokenizer = load_model_and_tokenizer(config) + + trainer = GRPOTrainer( + model=model, + processing_class=tokenizer, + train_dataset=dataset, + reward_funcs=reward_from_environments, + args=trl_args, + environment_factory=partial(SqlDriftToolEnv, env_url=config.env_base_url), + ) + trainer.add_callback(_build_flush_log_history_callback(out / "log_history.jsonl")) + + _LOG.info( + "Starting GRPO: scenarios=%d, max_steps=%d, group_size=%d, env_url=%s", + len(config.curriculum.scenarios), + config.max_steps, + config.group_size, + config.env_base_url, + ) + trainer.train() + trainer.save_model(str(out)) + return trainer + + +def _parse_args(argv: list[str] | None = None) -> GRPOConfig: + import argparse + + ap = argparse.ArgumentParser(description="GRPO training for SQLDrift.") + ap.add_argument( + "--config", + type=Path, + default=None, + help="Optional JSON manifest overriding GRPOConfig defaults.", + ) + ap.add_argument("--max-steps", type=int, default=None) + ap.add_argument("--output-dir", type=str, default=None) + ap.add_argument("--env-base-url", type=str, default=None) + ns = ap.parse_args(argv) + + overrides: dict[str, Any] = {} + if ns.config is not None: + overrides.update(json.loads(ns.config.read_text())) + if ns.max_steps is not None: + overrides["max_steps"] = ns.max_steps + if ns.output_dir is not None: + overrides["output_dir"] = ns.output_dir + if ns.env_base_url is not None: + overrides["env_base_url"] = ns.env_base_url + from training.config import CurriculumConfig + + if "curriculum" in overrides and isinstance(overrides["curriculum"], dict): + overrides["curriculum"] = CurriculumConfig(**overrides["curriculum"]) + return GRPOConfig(**overrides) + + +def main(argv: list[str] | None = None) -> None: + cfg = _parse_args(argv) + train(cfg) + + +if __name__ == "__main__": + main() + + +__all__ = [ + "build_dataset", + "build_env_client", + "iter_curriculum", + "load_model_and_tokenizer", + "reward_from_environments", + "train", +] diff --git a/training/llm_agent.py b/training/llm_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..2fff2a897d66db6a4f12bd1f568b2bfb02ed7a2c --- /dev/null +++ b/training/llm_agent.py @@ -0,0 +1,380 @@ +"""LLM-backed agent used by :mod:`training.eval` for trained checkpoints. + +The eval harness only needs the :class:`.eval.Agent` protocol +(``reset(seed)`` + ``act(obs) -> SqlDriftAction``). This module +supplies a minimal, chat-template-driven policy that: + +1. Loads a saved model directory (either a full HF checkpoint or a PEFT + adapter pointing at a base model). +2. Maintains a bounded chat history across the episode so the model + sees its own prior tool calls and their observations. +3. Prompts the model to emit *exactly one* JSON tool-call envelope per + turn (``{"tool": "...", "payload": {...}}``) and parses it into a + :class:`models.SqlDriftAction`. +4. Falls back to a safe default action on parse failure rather than + crashing the rollout — this matches the random-agent contract and + keeps eval sweeps resilient to occasional generation noise. + +All heavy ML imports (``torch``, ``transformers``, ``peft``) are +deferred into :meth:`LLMAgent.__init__` so the module is importable on +CPU-only CI boxes for type checking. +""" + +from __future__ import annotations + +import json +import re +from pathlib import Path +from typing import Any + +from models import ( + ConsultDBAResult, + DescribeTableResult, + ExplainQueryResult, + ListTablesPayload, + ListTablesResult, + ReadChangelogResult, + RunQueryResult, + SampleRowsResult, + SqlDriftAction, + SqlDriftObservation, + SubmitRewriteResult, + ToolError, + ToolName, + ToolPayload, +) +from training.prompt import render_system_prompt +from utilities.logger import get_module_logger, log_interaction + +_LOG = get_module_logger(__name__) + +# Compact, model-facing JSON contract. Kept short because it ships with +# every turn and its tokens count against ``max_seq_length``. +_TOOL_CONTRACT = ( + "Respond with EXACTLY ONE JSON object per turn and nothing else:\n" + '{"tool": "", "payload": {...}}\n' + "Valid tool names: list_tables, describe_table, sample_rows, run_query, " + "explain_query, read_changelog, submit_rewrite, consult_dba.\n" + "Payload schemas (match one):\n" + '- list_tables: {"kind": "list_tables"}\n' + '- describe_table: {"kind": "describe_table", "table": ""}\n' + '- sample_rows: {"kind": "sample_rows", "table": "", "limit": 1..5}\n' + '- run_query: {"kind": "run_query", "sql": ""}\n' + '- read_changelog: {"kind": "read_changelog"}\n' + '- submit_rewrite: {"kind": "submit_rewrite", "sql": "