junaid0600's picture
Round 2: SQL Database Engineer Agent - 24/24 tests passing
8cb206e
Raw
History Blame Contribute Delete
14.4 kB
import json
import random
from pathlib import Path
from env.models import DifficultyLevel, TaskInfo
# ─────────────────────────────────────────────
# LOAD DATASETS β€” Round 1 + Round 2
# ─────────────────────────────────────────────
BASE_DIR = Path(__file__).parent.parent / "dataset"
def _load(filename: str) -> list[dict]:
path = BASE_DIR / filename
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
# Round 1 cases (keep for backward compatibility)
EASY_CASES = _load("easy_cases.json")
MEDIUM_CASES = _load("medium_cases.json")
HARD_CASES = _load("hard_cases.json")
# Round 2 scenarios (new long-horizon DB engineering tasks)
EASY_SCENARIOS = _load("easy_scenarios.json")
MEDIUM_SCENARIOS = _load("medium_scenarios.json")
HARD_SCENARIOS = _load("hard_scenarios.json")
# Combined pools β€” Round 2 scenarios take priority (listed first)
ALL_CASES: dict[str, list[dict]] = {
DifficultyLevel.EASY: EASY_SCENARIOS + EASY_CASES,
DifficultyLevel.MEDIUM: MEDIUM_SCENARIOS + MEDIUM_CASES,
DifficultyLevel.HARD: HARD_SCENARIOS + HARD_CASES,
}
# Round 2 only (for training pipeline)
SCENARIO_ONLY: dict[str, list[dict]] = {
DifficultyLevel.EASY: EASY_SCENARIOS,
DifficultyLevel.MEDIUM: MEDIUM_SCENARIOS,
DifficultyLevel.HARD: HARD_SCENARIOS,
}
# ─────────────────────────────────────────────
# ACTION SCHEMA (required by /tasks validator)
# ─────────────────────────────────────────────
ACTION_SCHEMA = {
# ── Round 1 actions ──────────────────────────────────────────
"identify_error": {
"description": "Identify where and what the error is without fixing it yet",
"payload_fields": {
"error_location": {"type": "string", "required": True, "description": "Where in the query the error occurs"},
"error_type": {"type": "string", "required": True, "description": "Type: syntax | logic | performance"},
"explanation": {"type": "string", "required": False, "description": "Brief explanation of the error"}
}
},
"propose_fix": {
"description": "Propose a fix without submitting as final answer",
"payload_fields": {
"fixed_query": {"type": "string", "required": True, "description": "The proposed corrected SQL query"},
"change_made": {"type": "string", "required": True, "description": "What specifically was changed"},
"confidence": {"type": "float", "required": False, "description": "Confidence score 0.0-1.0"}
}
},
"submit_answer": {
"description": "Submit the final fixed query as the definitive answer",
"payload_fields": {
"fixed_query": {"type": "string", "required": True, "description": "Final corrected SQL query"},
"explanation": {"type": "string", "required": True, "description": "Full explanation of fix"},
"error_type": {"type": "string", "required": False, "description": "syntax | logic | performance"},
"confidence": {"type": "float", "required": False, "description": "Confidence 0.0-1.0"}
}
},
"request_hint": {
"description": "Request a hint β€” costs 0.10 reward penalty per hint",
"payload_fields": {
"hint_type": {"type": "string", "required": False, "description": "location | error_type | fix_direction"}
}
},
"explain_issue": {
"description": "Explain the issue in detail",
"payload_fields": {
"explanation": {"type": "string", "required": True, "description": "Detailed explanation"},
"impact": {"type": "string", "required": False, "description": "Impact on query performance"},
"root_cause": {"type": "string", "required": False, "description": "Root cause analysis"}
}
},
"optimize_query": {
"description": "Submit an optimized version of the query",
"payload_fields": {
"optimized_query": {"type": "string", "required": True, "description": "Optimized SQL"},
"optimization_type": {"type": "string", "required": True, "description": "What optimization was applied"},
"expected_improvement":{"type": "string", "required": False, "description": "Expected performance gain"},
"explanation": {"type": "string", "required": False, "description": "Why this optimization works"},
"confidence": {"type": "float", "required": False, "description": "Confidence 0.0-1.0"}
}
},
# ── Round 2 actions ──────────────────────────────────────────
"inspect_query": {
"description": "EXPLAIN a slow query β€” reveals scan type, rows examined, index usage",
"payload_fields": {
"query_id": {"type": "string", "required": True, "description": "ID of slow query to inspect (e.g. 'q1')"}
}
},
"analyze_indexes": {
"description": "Show all indexes on a table + usage frequency + missing index hints",
"payload_fields": {
"table": {"type": "string", "required": True, "description": "Table name to analyze"}
}
},
"create_index": {
"description": "Add a composite index on specified columns β€” core optimization action",
"payload_fields": {
"table": {"type": "string", "required": True, "description": "Table to index"},
"columns": {"type": "list|string", "required": True, "description": "Columns to index (list or comma-separated string)"}
}
},
"rewrite_query": {
"description": "Submit a rewritten SQL query β€” system evaluates execution time improvement",
"payload_fields": {
"query_id": {"type": "string", "required": True, "description": "ID of query to rewrite"},
"new_sql": {"type": "string", "required": True, "description": "Rewritten SQL query"}
}
},
"add_column": {
"description": "Add a denormalization column to reduce expensive JOINs",
"payload_fields": {
"table": {"type": "string", "required": True, "description": "Table to modify"},
"column": {"type": "string", "required": True, "description": "New column name"},
"purpose": {"type": "string", "required": False, "description": "Why this column helps"}
}
},
"drop_index": {
"description": "Remove an unused index to reduce write overhead",
"payload_fields": {
"table": {"type": "string", "required": True, "description": "Table name"},
"index_name": {"type": "string", "required": True, "description": "Index name to drop (cannot drop PRIMARY)"}
}
},
"partition_table": {
"description": "Partition a large table by date or ID range for range query efficiency",
"payload_fields": {
"table": {"type": "string", "required": True, "description": "Table to partition"},
"partition_by": {"type": "string", "required": False, "description": "Column to partition on (e.g. 'created_at')"},
"partition_type": {"type": "string", "required": False, "description": "RANGE | LIST | HASH"}
}
},
"analyze_statistics": {
"description": "Update table statistics for query planner accuracy",
"payload_fields": {
"table": {"type": "string", "required": True, "description": "Table to analyze"}
}
},
"submit_report": {
"description": "TERMINAL: Submit final optimization report β€” ends episode, computes full score",
"payload_fields": {
"summary": {"type": "string", "required": True, "description": "Summary of optimizations applied"},
"actions_taken": {"type": "list", "required": False, "description": "List of key actions taken"},
"expected_gain": {"type": "string", "required": False, "description": "Expected performance improvement"}
}
},
}
# ─────────────────────────────────────────────
# TASK MANAGER
# ─────────────────────────────────────────────
class TaskManager:
"""
Manages task selection for both Round 1 and Round 2 scenarios.
Round 2 scenarios have tables/slow_queries structure.
Round 1 cases have buggy_query structure.
"""
def __init__(self):
self._used_ids: set[str] = set()
def get_task(self, difficulty: DifficultyLevel, task_id: str | None = None) -> dict:
"""
Returns a task for the given difficulty.
Prefers Round 2 scenarios, falls back to Round 1 cases.
"""
pool = ALL_CASES[difficulty]
if task_id:
for case in pool:
if case["id"] == task_id:
return case
raise ValueError(f"Task '{task_id}' not found in {difficulty} pool")
# Avoid recently used tasks
available = [c for c in pool if c["id"] not in self._used_ids]
if not available:
self._used_ids.clear()
available = pool
task = random.choice(available)
self._used_ids.add(task["id"])
return task
def get_random_task(self) -> dict:
difficulty = random.choice(list(DifficultyLevel))
return self.get_task(difficulty)
def get_scenario(self, difficulty: DifficultyLevel, scenario_id: str | None = None) -> dict:
"""Get Round 2 scenario specifically."""
pool = SCENARIO_ONLY[difficulty]
if scenario_id:
for s in pool:
if s["id"] == scenario_id:
return s
raise ValueError(f"Scenario '{scenario_id}' not found")
return random.choice(pool)
def build_observation_context(self, task: dict) -> dict:
"""
Builds current_context for the Observation.
Handles both Round 2 scenario format and Round 1 case format.
CRITICAL: Never leaks ground truth (fixed_query / optimal_actions).
"""
# ── Round 2 scenario format ───────────────────────────────
if "slow_queries" in task:
return {
"scenario_id": task["id"],
"description": task.get("description", ""),
"tables": task.get("tables", []),
"slow_queries": task.get("slow_queries", []),
"performance_score_baseline": task.get("performance_score_baseline", 0.0),
"target_score": task.get("target_score", 85.0),
"max_steps": task.get("max_steps", 50),
"category": task.get("category", ""),
# Do NOT include missing_index_hints (that's the answer)
# Do NOT include optimal_actions (that's the answer)
}
# ── Round 1 case format (backward compatible) ────────────
context = {
"buggy_query": task.get("buggy_query", ""),
"error_message": task.get("error_message", ""),
"database_schema": task.get("database_schema", ""),
"error_type_hint": task.get("error_type", ""),
"category": task.get("category", ""),
"estimated_steps": task.get("estimated_fix_steps", 5),
}
if task.get("performance_issue"):
context["performance_issue"] = {
"type": task["performance_issue"]["type"],
"impact": task["performance_issue"]["impact"],
}
if task.get("expected_output") and isinstance(task["expected_output"], list):
context["expected_output_sample"] = task["expected_output"][:1]
return context
def get_hint(self, task: dict, hint_number: int) -> str:
"""Progressive hints. Each hint reveals more info. Costs -0.10 each."""
# Round 2 scenario hints
if "slow_queries" in task:
hints = [
f"Hint 1: Start by inspecting your slow queries with inspect_query action.",
f"Hint 2: Use analyze_indexes on tables appearing in slow queries.",
f"Hint 3: Category is '{task.get('category', 'indexing')}'. Target score: {task.get('target_score', 85.0)}.",
]
else:
# Round 1 hints
hints = [
f"Hint 1: The error is in the {task.get('error_location', 'query')}.",
f"Hint 2: This is a {task.get('error_type', 'unknown')} error. Category: {task.get('category')}.",
f"Hint 3: Fix: {task.get('fix_description', 'Review the query carefully.')}",
]
idx = min(hint_number - 1, len(hints) - 1)
return hints[max(0, idx)]
def list_all_tasks(self) -> list[TaskInfo]:
"""Returns TaskInfo list for the /tasks endpoint β€” all 30 tasks."""
result = []
for difficulty, cases in ALL_CASES.items():
for case in cases:
result.append(TaskInfo(
id = case["id"],
difficulty = difficulty,
description = case.get("description", ""),
action_schema = ACTION_SCHEMA
))
return result
def get_ground_truth(self, task_id: str) -> dict | None:
"""Returns full task including ground truth (used by grader only)."""
for cases in ALL_CASES.values():
for case in cases:
if case["id"] == task_id:
return case
return None
# Singleton instance
task_manager = TaskManager()