Spaces:
Sleeping
Sleeping
| """ | |
| training/generate_training_data.py | |
| Generates training data by running episodes on the live environment. | |
| Saves (prompt, action, reward) tuples for GRPO training. | |
| Run this BEFORE train_agent.py to pre-generate data. | |
| """ | |
| import os | |
| import json | |
| import time | |
| import requests | |
| from pathlib import Path | |
# Base URL of the environment server. An empty ENV_URL falls back to the default,
# and a trailing slash is stripped so f"{ENV_URL}/reset" never produces "//reset".
ENV_URL = (
    os.getenv("ENV_URL") or "https://junaid0600-sql-db-engineer-agent.hf.space"
).rstrip("/")
# Directory that receives training_data.jsonl / training_data.json; created at import time.
OUTPUT_DIR = Path(os.getenv("OUTPUT_DIR", "./sdea-trained"))
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Every Round 2 scenario ID: five numbered scenarios per difficulty tier,
# ordered easy -> medium -> hard and s001 -> s005 within each tier.
ALL_SCENARIOS = [
    f"{tier}_s{number:03d}"
    for tier in ("easy", "medium", "hard")
    for number in range(1, 6)
]
def _action(action_type: str, **payload) -> dict:
    """Build one action dict in the shape posted to the environment's /step endpoint.

    Keyword-argument order is preserved (PEP 468), so ``json.dumps`` of the
    result matches a hand-written literal with the same key order.
    """
    return {"action_type": action_type, "payload": payload}


# Optimal action sequences per scenario (expert demonstrations).
# NOTE: hard_s002..hard_s005 have no entry here; run_expert_episode falls back
# to a generic three-step sequence for those scenarios.
EXPERT_ACTIONS = {
    "easy_s001": [
        _action("inspect_query", query_id="q1"),
        _action("analyze_indexes", table="users"),
        _action("create_index", table="users", columns=["email"]),
        _action(
            "submit_report",
            summary="Added index on users(email). Email lookup now uses index scan instead of full table scan.",
        ),
    ],
    "easy_s002": [
        _action("inspect_query", query_id="q1"),
        _action("create_index", table="orders", columns=["user_id", "status"]),
        _action(
            "submit_report",
            summary="Composite index on orders(user_id, status) eliminates full table scan.",
        ),
    ],
    "easy_s003": [
        _action("inspect_query", query_id="q1"),
        _action("create_index", table="products", columns=["name"]),
        _action("submit_report", summary="Index on products(name) speeds up LIKE queries."),
    ],
    "easy_s004": [
        _action("inspect_query", query_id="q1"),
        _action("create_index", table="sessions", columns=["user_id", "expires_at"]),
        _action("submit_report", summary="Composite index on sessions(user_id, expires_at) added."),
    ],
    "easy_s005": [
        _action("inspect_query", query_id="q1"),
        _action("create_index", table="logs", columns=["level", "created_at"]),
        _action("submit_report", summary="Compound index on logs(level, created_at) added."),
    ],
    "medium_s001": [
        _action("inspect_query", query_id="q1"),
        _action("inspect_query", query_id="q2"),
        _action("analyze_indexes", table="orders"),
        _action("create_index", table="orders", columns=["user_id", "status"]),
        _action("create_index", table="users", columns=["country"]),
        _action("analyze_statistics", table="orders"),
        _action("submit_report", summary="Two indexes added. Both slow queries now use index scans."),
    ],
    "medium_s002": [
        _action("inspect_query", query_id="q1"),
        _action("inspect_query", query_id="q2"),
        _action("create_index", table="posts", columns=["author_id", "published", "created_at"]),
        _action("create_index", table="authors", columns=["username"]),
        _action("submit_report", summary="Multi-column index on posts and unique index on authors added."),
    ],
    "medium_s003": [
        _action("inspect_query", query_id="q1"),
        _action("inspect_query", query_id="q2"),
        _action(
            "create_index",
            table="stock_movements",
            columns=["product_id", "movement_type", "created_at"],
        ),
        _action(
            "rewrite_query",
            query_id="q2",
            new_sql="SELECT p.id, p.name, SUM(sm.quantity) FROM products p INNER JOIN stock_movements sm ON p.id = sm.product_id GROUP BY p.id, p.name",
        ),
        _action("analyze_statistics", table="stock_movements"),
        _action(
            "submit_report",
            summary="Index + query rewrite applied. Implicit join converted to explicit INNER JOIN.",
        ),
    ],
    "medium_s004": [
        _action("inspect_query", query_id="q1"),
        _action("inspect_query", query_id="q2"),
        _action("analyze_indexes", table="tickets"),
        _action("create_index", table="tickets", columns=["status", "priority", "created_at"]),
        _action("create_index", table="tickets", columns=["status", "agent_id"]),
        _action(
            "submit_report",
            summary="Two targeted indexes on tickets table eliminate full table scans.",
        ),
    ],
    "medium_s005": [
        _action("inspect_query", query_id="q1"),
        _action("inspect_query", query_id="q2"),
        _action("create_index", table="events", columns=["user_id", "event_type", "occurred_at"]),
        _action("create_index", table="users", columns=["signup_source", "created_at"]),
        _action("analyze_statistics", table="events"),
        _action(
            "submit_report",
            summary="Range query indexes added for both events and users tables.",
        ),
    ],
    "hard_s001": [
        _action("inspect_query", query_id="q1"),
        _action("inspect_query", query_id="q2"),
        _action("inspect_query", query_id="q3"),
        _action("analyze_indexes", table="transactions"),
        _action("create_index", table="transactions", columns=["account_id", "status", "created_at"]),
        _action("create_index", table="transactions", columns=["customer_id", "amount"]),
        _action("create_index", table="audit_log", columns=["entity_id", "entity_type", "created_at"]),
        _action(
            "rewrite_query",
            query_id="q2",
            new_sql="SELECT c.id, c.name, COUNT(t.id) as tx_count FROM customers c INNER JOIN transactions t ON c.id = t.customer_id WHERE t.amount > ? GROUP BY c.id, c.name",
        ),
        _action("partition_table", table="audit_log", partition_by="created_at", partition_type="RANGE"),
        _action("analyze_statistics", table="transactions"),
        _action(
            "submit_report",
            summary="3 indexes added, implicit join rewritten, audit_log partitioned by date.",
        ),
    ],
}
def _build_prompt(obs: dict) -> str:
    """Render the text prompt for the policy from the current observation *obs*."""
    # `or {}` also covers a server that sends "current_context": null.
    ctx = obs.get("current_context") or {}
    steps_remaining = obs.get("max_steps", 50) - obs.get("step_count", 0)
    return f"""You are a senior database engineer.
Current DB state:
- Performance score: {ctx.get('performance_score', 0)} / {ctx.get('target_score', 85)}
- Slow queries: {json.dumps(ctx.get('slow_queries', []))}
- Tables: {json.dumps(ctx.get('tables', []))}
- Steps remaining: {steps_remaining}
Choose the best next action as JSON:"""


def _extract_reward(data: dict) -> float:
    """Return the scalar reward from a /step response body.

    Accepts ``{"reward": {"score": x}}`` (the shape this script was written
    against) as well as a bare numeric ``{"reward": x}``; any other shape,
    including a missing or null reward, yields 0.001.
    """
    reward = data.get("reward")
    if isinstance(reward, dict):
        return reward.get("score", 0.001)
    if isinstance(reward, (int, float)):
        return float(reward)
    return 0.001


def run_expert_episode(scenario_id: str, timeout: float = 15) -> list[dict]:
    """
    Run one expert episode and collect (prompt, action, reward) tuples.

    Resets the environment at ENV_URL to *scenario_id*, then replays the
    scripted actions from EXPERT_ACTIONS. Scenarios with no entry use a generic
    three-step fallback. Each step is sent to /step, and one record is kept per
    step. The prompt in a record is built from the observation seen BEFORE that
    step's action.

    Args:
        scenario_id: Task ID sent to the environment's /reset endpoint.
        timeout: Per-request HTTP timeout in seconds.

    Returns:
        One dict per executed step with keys scenario_id, prompt, action
        (JSON string), reward, db_delta and step. If the episode fails
        part-way, the steps collected so far are returned and the error is
        printed; exceptions are not propagated.
    """
    # NOTE(review): the fallback targets `orders(user_id)`, which may not exist
    # in every scenario's schema (it is used for hard_s002..hard_s005).
    actions = EXPERT_ACTIONS.get(scenario_id, [
        {"action_type": "inspect_query", "payload": {"query_id": "q1"}},
        {"action_type": "create_index", "payload": {"table": "orders", "columns": ["user_id"]}},
        {"action_type": "submit_report", "payload": {"summary": "Optimization applied."}},
    ])
    trajectory = []
    try:
        r = requests.post(f"{ENV_URL}/reset",
                          json={"task_id": scenario_id}, timeout=timeout)
        # Without this, an HTTP error body would be treated as an observation.
        r.raise_for_status()
        reset_data = r.json()
        # Step responses nest the observation under "observation"; accept the
        # same shape from /reset, otherwise treat the whole body as the observation.
        obs = reset_data.get("observation") or reset_data
        for action in actions:
            prompt = _build_prompt(obs)
            r = requests.post(f"{ENV_URL}/step", json=action, timeout=timeout)
            # Don't record error responses as training samples with a fake reward.
            r.raise_for_status()
            data = r.json()
            trajectory.append({
                "scenario_id": scenario_id,
                "prompt": prompt,
                "action": json.dumps(action),
                "reward": _extract_reward(data),
                "db_delta": (data.get("info") or {}).get("db_delta", 0),
                "step": (data.get("observation") or {}).get("step_count", 0),
            })
            obs = data.get("observation", obs)
            if data.get("done"):
                break
        print(f" β {scenario_id}: {len(trajectory)} steps, "
              f"final reward={trajectory[-1]['reward']:.3f}")
    except Exception as e:
        # Deliberate best-effort boundary: one failing scenario must not abort
        # the whole generation run; keep whatever steps were collected.
        print(f" β {scenario_id}: {e}")
    return trajectory
def generate_all():
    """Generate training data from all scenarios.

    Runs run_expert_episode for every ID in ALL_SCENARIOS, pausing 0.5 s
    between scenarios. It then writes the combined step records to
    OUTPUT_DIR/training_data.jsonl (one JSON object per line) and
    OUTPUT_DIR/training_data.json (indented list, for inspection), and prints
    summary reward statistics.

    Returns:
        The list of all collected step records, which may be empty if every
        episode failed.
    """
    print("π Generating training data...")
    print(f"π Environment: {ENV_URL}")
    print(f"π Output: {OUTPUT_DIR}")
    print("β" * 50)
    all_trajectories = []
    for scenario_id in ALL_SCENARIOS:
        print(f"Running {scenario_id}...")
        all_trajectories.extend(run_expert_episode(scenario_id))
        time.sleep(0.5)  # Be nice to HF Space
    total_steps = len(all_trajectories)
    # Save as JSONL for training
    output_file = OUTPUT_DIR / "training_data.jsonl"
    with open(output_file, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item) + "\n" for item in all_trajectories)
    # Also save as JSON for inspection
    json_file = OUTPUT_DIR / "training_data.json"
    with open(json_file, "w", encoding="utf-8") as f:
        json.dump(all_trajectories, f, indent=2)
    print("β" * 50)
    print(f"β Generated {total_steps} training steps from {len(ALL_SCENARIOS)} scenarios")
    print(f"π Saved to: {output_file}")
    # Stats. max()/min() raise ValueError on an empty list, so bail out early
    # when no episode produced any step (e.g. the Space is unreachable).
    rewards = [t["reward"] for t in all_trajectories]
    if not rewards:
        print("No rewards collected: every scenario failed. Check ENV_URL.")
        return all_trajectories
    avg_r = sum(rewards) / len(rewards)
    print(f"π Average reward: {avg_r:.3f}")
    print(f"π Max reward: {max(rewards):.3f}")
    print(f"π Min reward: {min(rewards):.3f}")
    return all_trajectories
if __name__ == "__main__":
    # Executed as a script: run every scenario and write the dataset files.
    generate_all()