---
title: SQLDrift
emoji: 🐘
colorFrom: blue
colorTo: indigo
sdk: docker
app_port: 8000
short_description: SQL repair and optimization under live schema drift
base_path: /web
---

SQLDrift

An OpenEnv gym that teaches an LLM agent to repair and optimize SQL when the database schema and business rules shift out from under it.

Why this problem exists

Production analytics breaks in familiar ways: a deploy renames a column, splits an enum into two fields, or changes how dates are stored. The dashboard query that worked on Monday may error or return wrong results on Tuesday. LLMs are a natural fit: they can read changelogs, run diagnostics, and propose rewrites. But static SQL benchmarks assume a fixed schema, so they cannot reward an agent for noticing that the world changed during a task, or for trading correctness against latency under a tool budget.

SQLDrift closes that gap: each episode is a small DuckDB world, a slow baseline SELECT, a fixed tool budget, and (for drift scenarios) a mid-episode schema or rule change the agent must recover from before submitting a rewrite that is both semantically correct and materially faster than the baseline.
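To make "the schema shifts mid-episode" concrete, here is a minimal sketch of an idempotent column-rename drift operator. It uses Python's built-in sqlite3 as a stand-in for DuckDB. The table and column names and the `rename_column_drift` helper are hypothetical illustrations, not the operators in engine/drift.py.

```python
import sqlite3


def column_names(conn: sqlite3.Connection, table: str) -> list[str]:
    # PRAGMA table_info returns one row per column; index 1 is the column name.
    return [row[1] for row in conn.execute(f"PRAGMA table_info({table})")]


def rename_column_drift(conn: sqlite3.Connection, table: str, old: str, new: str) -> bool:
    """Hypothetical drift operator: rename `old` to `new` exactly once.

    Returns True if the DDL ran and False if the drift was already applied,
    so calling it twice leaves the schema unchanged (idempotent).
    """
    cols = column_names(conn, table)
    if new in cols and old not in cols:
        return False  # already drifted: nothing to do
    if old not in cols:
        raise ValueError(f"{table}.{old} does not exist")
    conn.execute(f'ALTER TABLE {table} RENAME COLUMN "{old}" TO "{new}"')
    return True


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE orders (order_id INTEGER, cust_id INTEGER)")
conn.execute("INSERT INTO orders VALUES (1, 10), (2, 11)")

baseline = "SELECT cust_id FROM orders"
print(conn.execute(baseline).fetchall())  # the baseline works before drift

rename_column_drift(conn, "orders", "cust_id", "customer_id")
rename_column_drift(conn, "orders", "cust_id", "customer_id")  # second call is a no-op

try:
    conn.execute(baseline)
except sqlite3.OperationalError as exc:
    # After drift the old baseline no longer compiles; an agent must re-diagnose.
    print("baseline now fails:", exc)
```

Idempotence matters because the operator may be re-applied when a step is replayed. Checking the current schema before running DDL is the simplest way to make a second application harmless.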

Live environment (Hugging Face Space)

Run the OpenEnv server without cloning the repo:

SQLDrift on Hugging Face Spaces

For TRL / client rollouts, point SQL_DRIFT_ENV_URL (or your SqlDriftEnv / SqlDriftToolEnv base URL) at the Space runtime, for example https://visheshrathi-sql-drift-env.hf.space (the same host the training notebook health-checks by default). If you need many concurrent WebSocket sessions (e.g. num_generations > 1 in GRPO), duplicate the Space to your own account per the TRL OpenEnv concurrency notes.
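Before launching rollouts, it can help to confirm that the Space is reachable. The following is a minimal standard-library sketch that assumes the server answers `GET /health` with HTTP 200; /health is listed under Highlights, but its response body is not documented here, so the sketch checks only the status code. The `resolve_env_url` and `is_healthy` helpers are hypothetical.

```python
import os
import urllib.request

# Default taken from this README; override with SQL_DRIFT_ENV_URL.
DEFAULT_URL = "https://visheshrathi-sql-drift-env.hf.space"


def resolve_env_url() -> str:
    # Trailing slashes are dropped so f"{base}/health" never yields "//health".
    return os.environ.get("SQL_DRIFT_ENV_URL", DEFAULT_URL).rstrip("/")


def is_healthy(base_url: str, timeout: float = 10.0) -> bool:
    # Assumption: a healthy server answers GET /health with HTTP 200.
    try:
        with urllib.request.urlopen(f"{base_url}/health", timeout=timeout) as resp:
            return resp.status == 200
    except OSError:
        # URLError, HTTPError and socket timeouts are all OSError subclasses.
        return False
```

Usage: `is_healthy(resolve_env_url())` before building the trainer. A sleeping Space may take a while to wake up, so a generous timeout or a retry loop is reasonable.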

SQLDrift is a production-grade OpenEnv environment for RL fine-tuning with TRL’s GRPOTrainer (and Unsloth LoRA on consumer GPUs). Each episode gives the agent a slow baseline query, a lightly populated DuckDB database, read-only introspection and execution tools, and a 25-step budget. Mid-episode, a schema or business-rule drift event may fire; the agent must read the deploy changelog, re-diagnose, and submit a rewrite that is (a) semantically correct and (b) ≥ 1.2× faster than the baseline. A hierarchical rubric turns that into six reward signals (correctness, drift-adapt, speedup, step-tax, gatekeepers, consult-dba).
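The reward is described as six named signals reported side by side. The following is a minimal sketch of how such a breakdown could be assembled. It assumes, although the README does not state it, that the total is the plain unweighted sum of the child scores. Only the component keys come from this README. The `compose` helper and the numbers in the example are placeholders, not real rubric outputs.

```python
from dataclasses import dataclass, field

# Component keys as exposed in observation.reward_components.
COMPONENT_KEYS = (
    "r_correct", "r_drift", "r_speedup",
    "r_step_tax", "r_gatekeepers", "r_consult_dba",
)


@dataclass
class RewardBreakdown:
    components: dict[str, float] = field(default_factory=dict)

    @property
    def total(self) -> float:
        # Assumption: the parent rubric adds its children with equal weight.
        return sum(self.components.values())


def compose(**scores: float) -> RewardBreakdown:
    # Reject typos, and default missing components to 0.0 so every key is
    # always reported (keeps logged component tables rectangular).
    unknown = set(scores) - set(COMPONENT_KEYS)
    if unknown:
        raise KeyError(f"unknown reward components: {sorted(unknown)}")
    return RewardBreakdown({k: float(scores.get(k, 0.0)) for k in COMPONENT_KEYS})


# Placeholder numbers for illustration only.
b = compose(r_correct=1.0, r_speedup=0.5, r_step_tax=-0.2)
print(b.total, b.components)
```

Keeping every component visible, rather than only the scalar total, is what makes per-signal analysis possible during training.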

Highlights:

  • 10 hand-crafted scenarios: 6 classic anti-patterns (correlated subqueries, SELECT * joins, Cartesian joins, DISTINCT group-bys, nested subqueries, HAVING-as-WHERE) and 4 drift events (column rename, date format flip, enum rule split, field deprecation).
  • Deterministic fixtures: 15–30-column schemas and 2–4 tables per scenario, regenerated in-process at reset() from a seed. No Parquet files, no pre-baked SQLite dumps.
  • sqlglot-canonicalised baseline-verbatim gate: agents that merely reformat whitespace and resubmit the baseline cannot farm the +1.0 correctness bonus.
  • Skill library: 12 pre-seeded playbook/drift-card entries plus an on-disk, fcntl.flock-guarded JSON store that grows as the agent solves episodes; relevant entries are surfaced at the next reset() via Jaccard retrieval.
  • Feature-flagged DBA Oracle: 3-tier escalating hints per scenario, penalised by the ConsultDBA rubric. Off by default.
  • Dockerised: server/Dockerfile layers the env on ghcr.io/meta-pytorch/openenv-base and exposes /health, /reset, /step, and /ws out of the box.
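The skill-library bullet combines two mechanisms: a JSON store guarded by `fcntl.flock` and Jaccard retrieval. The following is a minimal POSIX-only sketch of both. It assumes that entries are dicts with a free-text `text` field and that tokens are lowercase word characters. The file layout and the `append_entry`/`retrieve` names are hypothetical, not the skill_library/ API.

```python
import fcntl
import json
import os
import re


def tokens(text: str) -> set[str]:
    return set(re.findall(r"[a-z0-9_]+", text.lower()))


def jaccard(a: set[str], b: set[str]) -> float:
    # |A ∩ B| / |A ∪ B|; two empty sets are treated as dissimilar.
    union = a | b
    return len(a & b) / len(union) if union else 0.0


def append_entry(path: str, entry: dict) -> None:
    # Exclusive lock so concurrent episodes cannot interleave read-modify-write.
    fd = os.open(path, os.O_RDWR | os.O_CREAT, 0o644)
    with os.fdopen(fd, "r+", encoding="utf-8") as f:
        fcntl.flock(f, fcntl.LOCK_EX)
        try:
            raw = f.read()
            entries = json.loads(raw) if raw.strip() else []
            entries.append(entry)
            f.seek(0)
            f.truncate()
            json.dump(entries, f)
            f.flush()
            os.fsync(f.fileno())
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)


def retrieve(path: str, query: str, k: int = 3) -> list[dict]:
    if not os.path.exists(path):
        return []
    with open(path, encoding="utf-8") as f:
        fcntl.flock(f, fcntl.LOCK_SH)  # shared lock: readers don't block each other
        try:
            raw = f.read()
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)
    entries = json.loads(raw) if raw.strip() else []
    q = tokens(query)
    scored = [(jaccard(q, tokens(e["text"])), e) for e in entries]
    scored = [pair for pair in scored if pair[0] > 0]  # drop zero-overlap entries
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [e for _, e in scored[:k]]
```

Jaccard similarity over token sets is cheap and needs no embedding model. The trade-off is that it only matches shared vocabulary, so "correlated subquery" will not retrieve an entry phrased purely as "dependent subselect".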

How the environment works

  1. reset(scenario_id, seed, …) builds a fresh in-memory DuckDB instance, loads the scenario’s tables and rows, and returns an observation with the baseline SQL, phase (DIAGNOSE → …), step budget, and learned_hints from the skill library.
  2. Tooling: the agent acts through OpenEnv actions (list/describe tables, sample rows, EXPLAIN, timed read-only run_query, read_changelog, optional consult_dba, and the terminal submit_rewrite). Each step consumes budget and accrues shaping reward (e.g. step tax and gatekeeper penalties).
  3. Drift: on drift scenarios, an event fires within a configured step window. DDL and rule operators run idempotently (engine/drift.py), the changelog updates, ground truth and baseline behaviour may change, and the observation signals drift_recovery so the agent can adapt.
  4. Verification: submissions are checked with an order-independent result hash (engine/verifier.py) and median-of-3 timed execution against the baseline (engine/profiler.py). The correctness and speedup components of the rubric enforce the ≥ 1.2× speedup bar for the top correctness tier.
  5. Reward: SqlDriftRubric composes six child rubrics; every observation exposes reward_components (e.g. r_correct, r_drift, r_speedup, r_step_tax, r_gatekeepers, r_consult_dba) for analysis and logging.
  6. Server: server/app.py uses the OpenEnv FastAPI factory; client.py provides a WebSocket SqlDriftEnv for remote episodes. See openenv.yaml for the Space-oriented manifest.
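The verification step can be sketched with three pieces: an order-independent result hash, median-of-3 timing, and the ≥ 1.2× gate. This minimal sketch uses sqlite3 as a stand-in for DuckDB. The hashing scheme (SHA-256 over sorted per-row digests) and the function names are assumptions and may differ from engine/verifier.py and engine/profiler.py.

```python
import hashlib
import sqlite3
import statistics
import time

SPEEDUP_BAR = 1.2  # from this README: the rewrite must be >= 1.2x faster


def result_hash(rows: list[tuple]) -> str:
    # Sorting per-row digests makes the hash ignore row order while still
    # counting duplicates (multiset semantics). repr-based, so 1 and 1.0 hash
    # differently; a real verifier would normalise value types first.
    digests = sorted(hashlib.sha256(repr(r).encode()).hexdigest() for r in rows)
    return hashlib.sha256("".join(digests).encode()).hexdigest()


def median_runtime(conn: sqlite3.Connection, sql: str, runs: int = 3) -> float:
    times = []
    for _ in range(runs):
        start = time.perf_counter()
        conn.execute(sql).fetchall()  # fetch everything so timing covers the full result
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def verify(conn: sqlite3.Connection, baseline: str, candidate: str) -> dict:
    same = result_hash(conn.execute(baseline).fetchall()) == result_hash(
        conn.execute(candidate).fetchall()
    )
    speedup = median_runtime(conn, baseline) / max(median_runtime(conn, candidate), 1e-9)
    return {"correct": same, "speedup": speedup,
            "passes_bar": same and speedup >= SPEEDUP_BAR}
```

Taking the median of three runs dampens one-off scheduler or cache noise without the cost of long benchmark loops. That matters when timing runs inside a 25-step RL episode.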

Results

Random baseline (floor)

We ship a reproducible random-agent evaluation under outputs/evals/baseline_random_v1/ (50 episodes: 5 seeds × 10 scenarios). It establishes a floor for any trained policy:

| Metric | Value |
|---|---|
| Mean total reward | −2.048 (σ ≈ 0.70) |
| Pass rate (reward ≥ 0.5) | 0% |
| Submit rate | 80% |

Mean reward is roughly −1.8 to −2.3 per scenario; drift scenarios trend slightly harder than static anti-pattern scenarios. See outputs/evals/baseline_random_v1/report.md for the full per-scenario table and component bars.
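The three floor metrics in the table can be recomputed from per-episode records. This minimal sketch assumes each record carries a `reward` value and a `submitted` flag. Those field names are hypothetical, so check per_episode.csv for the real columns. The 0.5 pass threshold comes from the table. The README does not say whether σ is the population or sample standard deviation, so the sketch uses the population form.

```python
import statistics
from typing import Iterable, Mapping

PASS_THRESHOLD = 0.5  # "Pass rate (reward >= 0.5)" from the table


def as_bool(value) -> bool:
    # CSV cells arrive as strings, and bool("False") is True, so parse explicitly.
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in {"1", "true", "yes"}


def floor_metrics(episodes: Iterable[Mapping]) -> dict:
    eps = list(episodes)
    if not eps:
        raise ValueError("no episodes")
    rewards = [float(e["reward"]) for e in eps]
    return {
        "mean_reward": statistics.fmean(rewards),
        "std_reward": statistics.pstdev(rewards),  # population sigma (assumption)
        "pass_rate": sum(r >= PASS_THRESHOLD for r in rewards) / len(rewards),
        "submit_rate": sum(as_bool(e["submitted"]) for e in eps) / len(eps),
    }
```

Feeding `csv.DictReader` rows straight into `floor_metrics` works because both numeric and boolean cells are parsed from strings.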

Trained policies

GRPO training is driven by training/grpo_train.py and sql_drift_grpo_training.ipynb. After a run, compare checkpoints with:

uv run python -m training.eval \
    --checkpoint path/to/adapter \
    --scenarios 1-10 \
    --seeds-per-scenario 5 \
    --out outputs/evals/my_run

Training logs and plots: training/grpo_train.py can emit step-wise JSONL; utilities/plot_curves.py produces reward/loss figures under training/evidence/ when that log exists.
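If the step-wise JSONL log exists, each line holds one JSON object. The following minimal sketch loads it into aligned columns before plotting. It assumes keys named `step`, `reward`, and `loss`; these names are hypothetical, so inspect the log for the real keys. Blank and truncated lines are skipped, since an interrupted run can leave a partially written last line.

```python
import json


def load_jsonl_columns(
    path: str, keys: tuple[str, ...] = ("step", "reward", "loss")
) -> dict[str, list]:
    cols: dict[str, list] = {k: [] for k in keys}
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                continue  # e.g. a truncated last line from an interrupted run
            # Keep only records that have every key, so the columns stay aligned.
            if isinstance(rec, dict) and all(k in rec for k in keys):
                for k in keys:
                    cols[k].append(rec[k])
    return cols
```

The columns can then go straight to matplotlib, e.g. `plt.plot(cols["step"], cols["reward"])`.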

Quick start

git clone <your-fork>/sql_drift_env.git
cd sql_drift_env
uv sync                                        # install deps
uv run pytest -q                               # 300+ tests, ~90s on CPU
uv run uvicorn server.app:app --reload         # serve the env on :8000

Or via Docker:

docker build -f server/Dockerfile -t sql-drift-env:latest .
docker run -p 8000:8000 sql-drift-env:latest
curl -s http://localhost:8000/health

Programmatic rollout

A five-step in-process episode against scenario 01_correlated_subquery:

from client import SqlDriftEnv
from server import SqlDriftEnvironment

env = SqlDriftEnvironment()
obs = env.reset(seed=42, scenario_id="01_correlated_subquery")
print(obs.learned_hints)             # surfaced from skill library

obs = env.step(SqlDriftEnv.action_list_tables())
obs = env.step(SqlDriftEnv.action_describe_table("users"))
obs = env.step(SqlDriftEnv.action_run_query("SELECT COUNT(*) FROM users"))
obs = env.step(SqlDriftEnv.action_submit_rewrite(
    "SELECT u.*, COALESCE(c.n, 0) FROM users u "
    "LEFT JOIN (SELECT user_id, COUNT(*) AS n FROM orders GROUP BY 1) c "
    "ON c.user_id = u.user_id"
))
print(obs.reward, obs.reward_components)
env.close()
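The submitted rewrite in this episode replaces a per-row correlated subquery with a single grouped LEFT JOIN. The following self-contained check of that equivalence uses sqlite3 as a stand-in for DuckDB. The baseline query is an assumed shape for scenario 01, since the README does not print it, and the tables hold toy data.

```python
import sqlite3
from collections import Counter

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE users  (user_id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE orders (order_id INTEGER PRIMARY KEY, user_id INTEGER);
    INSERT INTO users  VALUES (1, 'ada'), (2, 'bob'), (3, 'cy');
    INSERT INTO orders VALUES (10, 1), (11, 1), (12, 3);
""")

# Assumed baseline shape: the subquery is re-evaluated once per users row.
baseline = """
    SELECT u.*, (SELECT COUNT(*) FROM orders o WHERE o.user_id = u.user_id)
    FROM users u
"""

# The rewrite from the example above: aggregate orders once, then join.
rewrite = """
    SELECT u.*, COALESCE(c.n, 0) FROM users u
    LEFT JOIN (SELECT user_id, COUNT(*) AS n FROM orders GROUP BY 1) c
    ON c.user_id = u.user_id
"""

# Compare as multisets, since neither query has an ORDER BY.
base_rows = Counter(conn.execute(baseline).fetchall())
new_rows = Counter(conn.execute(rewrite).fetchall())
print(base_rows == new_rows)
```

User 2 has no orders. The correlated `COUNT(*)` returns 0 for that user, while a bare LEFT JOIN would yield NULL, which is why the COALESCE is needed for the rewrite to pass a result-hash check.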

For an end-to-end episode against an HTTP+WS OpenEnv server, see SqlDriftEnv in client.py and the integration tests in tests/integration/test_client_server.py and tests/integration/test_state_no_leak.py.

Evaluation

uv run python -m training.eval \
    --checkpoint base \
    --scenarios 1-10 \
    --seeds-per-scenario 5 \
    --out outputs/evals/my_run

This writes report.md, per_episode.csv, and summary.json to the --out directory.
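For a quick per-scenario view without opening report.md, the following sketch groups per_episode.csv by scenario using the standard csv module. The column names `scenario_id` and `reward` are assumptions; adjust them to the header the eval actually writes.

```python
import csv
import statistics
from collections import defaultdict


def per_scenario_mean(
    csv_path: str, key: str = "scenario_id", value: str = "reward"
) -> dict[str, float]:
    groups: dict[str, list[float]] = defaultdict(list)
    with open(csv_path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            groups[row[key]].append(float(row[value]))
    # Sorted by scenario id so two runs can be diffed line by line.
    return {sid: statistics.fmean(vals) for sid, vals in sorted(groups.items())}
```

Calling it on outputs/evals/baseline_random_v1/per_episode.csv and on your own run's file gives two dicts you can compare scenario by scenario.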

Training (GPU)

training/grpo_train.py contains the GRPO entrypoint used by the hackathon training notebook: it builds the curriculum dataset, loads Qwen/Qwen3-1.7B (or your SQL_DRIFT_MODEL_NAME) via transformers.AutoModelForCausalLM + BitsAndBytesConfig 4-bit nf4 (QLoRA), attaches a PEFT LoRA adapter, and lets TRL's GRPOTrainer drive multi-turn OpenEnv rollouts through SqlDriftToolEnv via environment_factory. The install/runtime stack mirrors Hugging Face TRL's reference notebooks (grpo_trl_lora_qlora.ipynb + openenv_wordle_grpo.ipynb). Open sql_drift_grpo_training.ipynb on a free Colab T4, set SQL_DRIFT_ENV_URL to your deployed SQLDrift Space (see Live environment), and run all cells.

uv sync --extra train           # installs trl, transformers, accelerate, unsloth
uv sync --extra evidence        # matplotlib + pandas for utilities/plot_curves.py

Repository layout

sql_drift_env/
├── models.py                   # Pydantic v2 action/observation/state
├── client.py                   # SqlDriftEnv EnvClient (/ws)
├── engine/
│   ├── runtime.py              # private RuntimeEpisodeState
│   ├── drift.py                # 4 DDL drift operators
│   ├── reward.py               # SqlDriftRubric (6 child rubrics)
│   ├── verifier.py             # order-independent result hash
│   └── profiler.py             # median-of-3 timed execution
├── scenarios/                  # 10 hand-crafted scenario modules + registry
├── skill_library/              # pre-seeds + JSON store + Jaccard retrieval
├── actors/                     # engineering_manager (changelog), dba_oracle
├── server/                     # FastAPI app, Dockerfile, env wrapper class
├── training/                   # config, prompt, random_agent, grpo_train, eval
├── utilities/                  # env_loader, logger, plot_curves, run_training_job (HF Jobs), …
├── tests/                      # 300+ unit + integration tests
├── outputs/evals/              # baseline eval artifacts
└── design/                     # design docs (ignored by docker)

References