# Source: sql-db-engineer-agent/training/generate_training_data.py (uploaded by junaid0600)
# Commit f30d05a: "Final Round 2: all checks passing, openenv validate OK"
# File size: 10.8 kB
"""
training/generate_training_data.py
Generates training data by replaying scripted expert episodes against the live environment.
Saves one record per step (scenario_id, prompt, action, reward, db_delta, step) as JSONL and JSON for GRPO training.
Run this BEFORE train_agent.py to pre-generate data.
"""
import os
import json
import time
import requests
from pathlib import Path
# Base URL of the environment server; the ENV_URL environment variable overrides it.
ENV_URL = os.environ.get(
    "ENV_URL",
    "https://junaid0600-sql-db-engineer-agent.hf.space",
)

# Directory that receives the generated files; the OUTPUT_DIR environment
# variable overrides it. The directory is created as soon as the module loads.
OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./sdea-trained"))
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Every Round 2 scenario ID: numbers 001-005 for each difficulty tier,
# ordered easy, then medium, then hard.
ALL_SCENARIOS = [
    f"{tier}_s{number:03d}"
    for tier in ("easy", "medium", "hard")
    for number in range(1, 6)
]
# Scripted expert demonstrations: scenario ID -> ordered list of actions.
# Each action is {"action_type": ..., "payload": {...}}. Key order is kept
# stable because the actions are later serialised with json.dumps.


def _act(action_type, **payload):
    """Build one action dict; payload keys keep the order they are passed in."""
    return {"action_type": action_type, "payload": payload}


def _inspect(query_id):
    """Action that inspects a single query."""
    return _act("inspect_query", query_id=query_id)


def _indexes(table):
    """Action that analyzes the indexes of a table."""
    return _act("analyze_indexes", table=table)


def _index(table, *columns):
    """Action that creates an index on the given columns of a table."""
    return _act("create_index", table=table, columns=list(columns))


def _stats(table):
    """Action that analyzes the statistics of a table."""
    return _act("analyze_statistics", table=table)


def _rewrite(query_id, new_sql):
    """Action that replaces the SQL text of a query."""
    return _act("rewrite_query", query_id=query_id, new_sql=new_sql)


def _partition(table, partition_by, partition_type):
    """Action that partitions a table."""
    return _act(
        "partition_table",
        table=table,
        partition_by=partition_by,
        partition_type=partition_type,
    )


def _report(summary):
    """Final action that submits the written summary."""
    return _act("submit_report", summary=summary)


EXPERT_ACTIONS = {
    "easy_s001": [
        _inspect("q1"),
        _indexes("users"),
        _index("users", "email"),
        _report("Added index on users(email). Email lookup now uses index scan instead of full table scan."),
    ],
    "easy_s002": [
        _inspect("q1"),
        _index("orders", "user_id", "status"),
        _report("Composite index on orders(user_id, status) eliminates full table scan."),
    ],
    "easy_s003": [
        _inspect("q1"),
        _index("products", "name"),
        _report("Index on products(name) speeds up LIKE queries."),
    ],
    "easy_s004": [
        _inspect("q1"),
        _index("sessions", "user_id", "expires_at"),
        _report("Composite index on sessions(user_id, expires_at) added."),
    ],
    "easy_s005": [
        _inspect("q1"),
        _index("logs", "level", "created_at"),
        _report("Compound index on logs(level, created_at) added."),
    ],
    "medium_s001": [
        _inspect("q1"),
        _inspect("q2"),
        _indexes("orders"),
        _index("orders", "user_id", "status"),
        _index("users", "country"),
        _stats("orders"),
        _report("Two indexes added. Both slow queries now use index scans."),
    ],
    "medium_s002": [
        _inspect("q1"),
        _inspect("q2"),
        _index("posts", "author_id", "published", "created_at"),
        _index("authors", "username"),
        _report("Multi-column index on posts and unique index on authors added."),
    ],
    "medium_s003": [
        _inspect("q1"),
        _inspect("q2"),
        _index("stock_movements", "product_id", "movement_type", "created_at"),
        _rewrite(
            "q2",
            "SELECT p.id, p.name, SUM(sm.quantity) FROM products p INNER JOIN stock_movements sm ON p.id = sm.product_id GROUP BY p.id, p.name",
        ),
        _stats("stock_movements"),
        _report("Index + query rewrite applied. Implicit join converted to explicit INNER JOIN."),
    ],
    "medium_s004": [
        _inspect("q1"),
        _inspect("q2"),
        _indexes("tickets"),
        _index("tickets", "status", "priority", "created_at"),
        _index("tickets", "status", "agent_id"),
        _report("Two targeted indexes on tickets table eliminate full table scans."),
    ],
    "medium_s005": [
        _inspect("q1"),
        _inspect("q2"),
        _index("events", "user_id", "event_type", "occurred_at"),
        _index("users", "signup_source", "created_at"),
        _stats("events"),
        _report("Range query indexes added for both events and users tables."),
    ],
    "hard_s001": [
        _inspect("q1"),
        _inspect("q2"),
        _inspect("q3"),
        _indexes("transactions"),
        _index("transactions", "account_id", "status", "created_at"),
        _index("transactions", "customer_id", "amount"),
        _index("audit_log", "entity_id", "entity_type", "created_at"),
        _rewrite(
            "q2",
            "SELECT c.id, c.name, COUNT(t.id) as tx_count FROM customers c INNER JOIN transactions t ON c.id = t.customer_id WHERE t.amount > ? GROUP BY c.id, c.name",
        ),
        _partition("audit_log", "created_at", "RANGE"),
        _stats("transactions"),
        _report("3 indexes added, implicit join rewritten, audit_log partitioned by date."),
    ],
}
def _build_prompt(obs: dict) -> str:
    """Render one observation into the text prompt stored with a step."""
    # A null "current_context" is treated the same as a missing one.
    ctx = obs.get("current_context") or {}
    steps_left = obs.get("max_steps", 50) - obs.get("step_count", 0)
    return (
        "You are a senior database engineer.\n"
        "Current DB state:\n"
        f"- Performance score: {ctx.get('performance_score', 0)} / {ctx.get('target_score', 85)}\n"
        f"- Slow queries: {json.dumps(ctx.get('slow_queries', []))}\n"
        f"- Tables: {json.dumps(ctx.get('tables', []))}\n"
        f"- Steps remaining: {steps_left}\n"
        "Choose the best next action as JSON:"
    )


def run_expert_episode(scenario_id: str) -> list[dict]:
    """Replay the scripted actions for one scenario and record every step.

    The environment is reset via ``POST {ENV_URL}/reset`` and each action is
    sent to ``POST {ENV_URL}/step``. The prompt stored with a step is built
    from the observation seen *before* that action was taken.

    Args:
        scenario_id: Scenario to reset the environment to. Scenarios without
            an entry in ``EXPERT_ACTIONS`` use a generic three-step fallback.

    Returns:
        One dict per executed step with the keys ``scenario_id``, ``prompt``,
        ``action`` (JSON string), ``reward``, ``db_delta`` and ``step``.
        On any failure the steps collected so far are returned (possibly an
        empty list); errors are printed, never raised.
    """
    actions = EXPERT_ACTIONS.get(scenario_id, [
        {"action_type": "inspect_query", "payload": {"query_id": "q1"}},
        {"action_type": "create_index", "payload": {"table": "orders", "columns": ["user_id"]}},
        {"action_type": "submit_report", "payload": {"summary": "Optimization applied."}},
    ])
    trajectory: list[dict] = []
    try:
        # Reset
        r = requests.post(f"{ENV_URL}/reset",
                          json={"task_id": scenario_id}, timeout=15)
        # Fail the episode on HTTP errors rather than parsing an error body
        # and recording default rewards as if they were real data.
        r.raise_for_status()
        obs = r.json()

        for action in actions:
            # Build prompt from current observation
            prompt = _build_prompt(obs)

            # Take action
            r = requests.post(f"{ENV_URL}/step", json=action, timeout=15)
            r.raise_for_status()
            data = r.json()

            # `or {}` / `or obs`: tolerate explicit nulls in the response.
            next_obs = data.get("observation") or {}
            trajectory.append({
                "scenario_id": scenario_id,
                "prompt": prompt,
                "action": json.dumps(action),
                "reward": (data.get("reward") or {}).get("score", 0.001),
                "db_delta": (data.get("info") or {}).get("db_delta", 0),
                "step": next_obs.get("step_count", 0),
            })
            # Keep the previous observation when the response carries none.
            obs = next_obs or obs
            if data.get("done"):
                break

        if trajectory:
            print(f" ✅ {scenario_id}: {len(trajectory)} steps, "
                  f"final reward={trajectory[-1]['reward']:.3f}")
        else:
            print(f" ❌ {scenario_id}: no steps recorded")
    except Exception as e:
        # Deliberate best-effort boundary: one failing scenario must not stop
        # the whole run, and already-collected steps are still returned.
        print(f" ❌ {scenario_id}: {e}")
    return trajectory
def generate_all() -> list[dict]:
    """Generate training data from all scenarios.

    Runs ``run_expert_episode`` for every ID in ``ALL_SCENARIOS`` (pausing
    0.5 s between episodes), then writes the combined step records to
    ``OUTPUT_DIR / "training_data.jsonl"`` (one JSON object per line) and
    ``OUTPUT_DIR / "training_data.json"`` (indented array), and prints
    reward statistics.

    Returns:
        The flat list of step records from all scenarios. It is empty when
        every episode failed; the files are still written in that case and
        the reward statistics are skipped.
    """
    print("🔄 Generating training data...")
    print(f"🌐 Environment: {ENV_URL}")
    print(f"📁 Output: {OUTPUT_DIR}")
    print("─" * 50)

    all_trajectories: list[dict] = []
    for scenario_id in ALL_SCENARIOS:
        print(f"Running {scenario_id}...")
        all_trajectories.extend(run_expert_episode(scenario_id))
        time.sleep(0.5)  # Be nice to HF Space

    # Save as JSONL for training
    output_file = OUTPUT_DIR / "training_data.jsonl"
    with open(output_file, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item) + "\n" for item in all_trajectories)

    # Also save as JSON for inspection
    json_file = OUTPUT_DIR / "training_data.json"
    with open(json_file, "w", encoding="utf-8") as f:
        json.dump(all_trajectories, f, indent=2)

    print("─" * 50)
    print(f"✅ Generated {len(all_trajectories)} training steps from {len(ALL_SCENARIOS)} scenarios")
    print(f"📄 Saved to: {output_file}")

    # Stats
    rewards = [t["reward"] for t in all_trajectories]
    if not rewards:
        # max()/min() raise ValueError on an empty sequence, which is what
        # happens when every episode failed.
        print("📊 No steps were collected; skipping reward statistics.")
        return all_trajectories

    avg_r = sum(rewards) / len(rewards)
    print(f"📊 Average reward: {avg_r:.3f}")
    print(f"📊 Max reward: {max(rewards):.3f}")
    print(f"📊 Min reward: {min(rewards):.3f}")
    return all_trajectories
# Script entry point: generate the data set only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    generate_all()