# SQLSHERLOCK-ENV / train.py
# (Hugging Face Space file-viewer residue: Swethaditya — Initial commit ce59113)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
SQLSherlock-Env β€” TRL GRPO Training Script.
Fine-tunes a language model via Group Relative Policy Optimisation (GRPO)
using the SQLSherlock RL environment as the reward signal.
The model learns the data-scientist investigation workflow:
profile β†’ hypothesise β†’ fix β†’ validate β†’ export
Environment variables:
SPACE_URL β€” SQLSherlock server URL (default: http://localhost:7860)
MODEL_ID β€” Base model to fine-tune (default: Qwen/Qwen2.5-1.5B-Instruct)
DATASET_NAME β€” Training dataset (default: mstz/titanic)
OUTPUT_DIR β€” Checkpoint output dir (default: ./grpo_output)
NUM_STEPS β€” Training steps (default: 200)
BATCH_SIZE β€” Batch size (default: 4)
TASK_ID β€” Task to train on (default: task1_null_and_types)
"""
import os
import sys
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
SPACE_URL = os.environ.get("SPACE_URL", "http://localhost:7860")  # SQLSherlock server base URL
MODEL_ID = os.environ.get("MODEL_ID", "Qwen/Qwen2.5-1.5B-Instruct")  # base model to fine-tune
DATASET_NAME = os.environ.get("DATASET_NAME", "phihung/titanic")  # HF dataset used for episodes
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./grpo_output")  # checkpoint / final-model directory
NUM_STEPS = int(os.environ.get("NUM_STEPS", "200"))  # max optimiser steps (GRPOConfig.max_steps)
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "4"))  # per-device batch size and generation count
TASK_ID = os.environ.get("TASK_ID", "task1_null_and_types")  # environment task identifier
# ---------------------------------------------------------------------------
# GRPO Environment wrapper
# ---------------------------------------------------------------------------
class SQLSherlockGRPOEnv:
    """Thin wrapper around SQLSherlockEnv exposing tool-call methods.
    Each public method corresponds to one action type. TRL's GRPO trainer
    calls reset() to start an episode, then the model calls methods
    as tool calls. The cumulative reward is read via reward_func().
    """
    def __init__(self) -> None:
        # Make the bundled sqlsherlock_env package importable before the
        # lazy `from client import ...` / `from models import ...` imports
        # used throughout this class.
        sys.path.insert(0, os.path.join(os.path.dirname(__file__), "sqlsherlock_env"))
        from client import SQLSherlockEnv
        self._env_class = SQLSherlockEnv
        self._client = None
        self.reward = 0.0  # cumulative reward for the current episode
        self._primary_table: str = "dataset"  # overwritten on reset()
    def _client_or_create(self):
        """Return the live client, lazily creating one on first use."""
        if self._client is None:
            self._client = self._env_class(base_url=SPACE_URL)
        return self._client
    def _step(self, **action_fields):
        """Send one action to the server and accumulate its reward.
        Args:
            **action_fields: Keyword fields forwarded to SQLSherlockAction
                (always includes ``action_type``).
        Returns:
            The observation object produced by the environment step.
        """
        from models import SQLSherlockAction
        # step() returns (observation, reward, done, info); `done` is
        # intentionally ignored here — the trainer controls episode length.
        obs, step_reward, _done, _info = self._client_or_create().step(
            SQLSherlockAction(**action_fields)
        )
        self.reward += step_reward
        return obs
    def reset(self, **kwargs) -> str:
        """Reset the environment and return a string observation.
        Args:
            dataset (str): HuggingFace dataset name or file path.
            task_id (str): Task identifier string.
        """
        # Fresh client each episode for isolation; closing the old one is
        # best-effort only.
        try:
            if self._client is not None:
                self._client.close()
        except Exception:
            pass
        # Reuse the class captured in __init__ instead of re-importing it.
        self._client = self._env_class(base_url=SPACE_URL)
        dataset = kwargs.get("dataset", DATASET_NAME)
        task_id = kwargs.get("task_id", TASK_ID)
        obs = self._client.reset(task_id=task_id, dataset=dataset)
        # NOTE(review): assumes tables_summary has at least one entry; the
        # first key is treated as the primary table — confirm server contract.
        self._primary_table = list(obs.tables_summary.keys())[0]
        self.reward = 0.0
        summary = obs.tables_summary[self._primary_table]
        return (
            f"Table: {self._primary_table}\n"
            f"Columns: {summary['columns']}\n"
            f"Rows: {summary['row_count']}\n"
            f"Task: {obs.task_description}"
        )
    def inspect_table(self, table: str) -> str:
        """View all rows in a database table.
        Args:
            table: Name of the table to inspect.
        """
        return self._step(action_type="inspect", table=table).last_feedback
    def profile_column(self, table: str, column: str) -> str:
        """Get statistical profile: mean, std, min, max, null_count, z-scores.
        Args:
            table: Table name containing the column.
            column: Column name to profile statistically.
        """
        return self._step(
            action_type="profile_column", table=table, column=column
        ).last_feedback
    def run_query(self, sql: str) -> str:
        """Execute a SELECT SQL query to find data quality issues.
        Args:
            sql: A SELECT SQL query string. No write operations allowed.
        """
        return self._step(action_type="run_sql", sql=sql).last_feedback
    def fix_cell(
        self,
        table: str,
        row_id: int,
        column: str,
        value: str,
        reason: str,
    ) -> str:
        """Fix a data quality issue in one cell.
        Args:
            table: Table name.
            row_id: Row primary key.
            column: Column to fix.
            value: The corrected value to write.
            reason: Statistical justification for this fix (e.g. z-score, median).
        """
        return self._step(
            action_type="fix_cell",
            table=table,
            row_id=row_id,
            column=column,
            value=value,
            reason=reason,
        ).last_feedback
    def delete_row(self, table: str, row_id: int, reason: str) -> str:
        """Delete a duplicate or FK-violation row.
        Args:
            table: Table name.
            row_id: Row primary key to delete.
            reason: Why this row should be removed (e.g. duplicate key detected).
        """
        return self._step(
            action_type="delete_row",
            table=table,
            row_id=row_id,
            reason=reason,
        ).last_feedback
    def validate(self) -> str:
        """Run all 6 validation checks comparing cleaned vs raw data.
        Call this after making fixes to verify your work is correct.
        Returns pass/fail status for each check.
        """
        return self._step(action_type="validate").last_feedback
    def submit(self) -> str:
        """Submit the investigation for final scoring.
        Call only when you have fixed all discovered issues and
        validate() shows improvement.
        """
        obs = self._step(action_type="submit")
        # The final score is the last entry of the reward trace, if any.
        last = obs.reward_trace[-1] if obs.reward_trace else {}
        return f"Final reward: {last.get('total', 0.0):.4f}"
# ---------------------------------------------------------------------------
# GRPO reward function
# ---------------------------------------------------------------------------
def reward_func(environments: list, **kwargs) -> list[float]:
    """Cumulative episode reward for each environment in the batch.
    TRL's GRPOTrainer invokes this after every rollout batch.
    Args:
        environments: SQLSherlockGRPOEnv instances, one per generation slot.
        **kwargs: Extra trainer-supplied arguments; ignored.
    Returns:
        One float reward per environment, in input order.
    """
    rewards: list[float] = []
    for environment in environments:
        rewards.append(environment.reward)
    return rewards
# ---------------------------------------------------------------------------
# Training entry point
# ---------------------------------------------------------------------------
def main() -> None:
    """Run GRPO fine-tuning against the SQLSherlock environment.
    Loads the base model and tokenizer, builds a small prompt dataset,
    wires up one wrapper environment per generation slot, and trains with
    TRL's GRPOTrainer. Exits with status 1 when the optional training
    dependencies (trl / transformers) are not installed.
    """
    try:
        from trl import GRPOConfig, GRPOTrainer
        from transformers import AutoTokenizer, AutoModelForCausalLM
    except ImportError:
        print(
            "Training dependencies not installed.\n"
            "Install with: pip install 'sqlsherlock-env[train]'\n"
            " or: pip install trl transformers torch"
        )
        sys.exit(1)
    # Banner showing the effective configuration for this run.
    print("SQLSherlock GRPO Training")
    print(f" Model : {MODEL_ID}")
    print(f" Dataset : {DATASET_NAME}")
    print(f" Task : {TASK_ID}")
    print(f" Steps : {NUM_STEPS}")
    print(f" Output : {OUTPUT_DIR}")
    print(f" Server : {SPACE_URL}")
    print()
    # Load model and tokenizer
    print("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
    if tokenizer.pad_token is None:
        # Causal LMs like Qwen often ship without a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token
    # Build a minimal training prompt dataset.
    # The model generates tool calls; the environment provides rewards.
    training_prompts = [
        {
            "prompt": (
                "You are a data scientist. Investigate the dataset for quality issues.\n"
                f"Dataset: {DATASET_NAME}\n"
                f"Task: {TASK_ID}\n"
                "Use the available tools: inspect_table, profile_column, run_query, "
                "fix_cell, delete_row, validate, submit.\n"
                "Start by inspecting the table."
            )
        }
        # At least 16 prompts so tiny batch sizes still fill an epoch.
        for _ in range(max(BATCH_SIZE * 4, 16))
    ]
    # GRPO configuration.
    # NOTE(review): the exact kwarg names (e.g. max_new_tokens vs
    # max_completion_length) vary between TRL releases — verify against the
    # pinned trl version.
    grpo_config = GRPOConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=1,
        max_steps=NUM_STEPS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=2,
        learning_rate=1e-5,
        logging_steps=10,
        save_steps=50,
        num_generations=BATCH_SIZE,
        max_new_tokens=256,
        temperature=0.7,
        report_to="none",
    )
    # Instantiate environments (one per generation slot).
    environments = [SQLSherlockGRPOEnv() for _ in range(BATCH_SIZE)]
    # Build the tools list for the trainer.
    # NOTE(review): every tool is bound to environments[0], yet reward_func
    # reads rewards from all environments — slots 1..N-1 never receive tool
    # calls and will always report reward 0.0. Confirm the intended wiring
    # with the trainer's tool-dispatch behaviour.
    tools = [
        environments[0].inspect_table,
        environments[0].profile_column,
        environments[0].run_query,
        environments[0].fix_cell,
        environments[0].delete_row,
        environments[0].validate,
        environments[0].submit,
    ]
    print("Starting GRPO training...")
    trainer = GRPOTrainer(
        model=model,
        args=grpo_config,
        tokenizer=tokenizer,
        train_dataset=training_prompts,
        reward_funcs=reward_func,
        env=environments,
        tools=tools,
    )
    trainer.train()
    print(f"\nTraining complete. Checkpoints saved to: {OUTPUT_DIR}")
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print(f"Final model saved to: {OUTPUT_DIR}")
if __name__ == "__main__":
main()