sql-db-engineer-agent / baseline.py
junaid0600's picture
changed
ea504bf
Raw
History Blame Contribute Delete
9.32 kB
import os
import time
from env.environment import SQLDebuggerEnvironment
from env.models import (
Action, ActionType, DifficultyLevel,
BaselineResponse, BaselineResult
)
# BASELINE AGENT
# Uses rule-based heuristics β€” no GPT-4
# Must complete within 60 seconds
# OPENAI_API_KEY must come from environment
def _check_api_key():
"""Edge case: OPENAI_API_KEY not set β†’ raise clear error."""
key = os.environ.get("OPENAI_API_KEY")
if not key:
raise ValueError(
"OPENAI_API_KEY environment variable is not set. "
"Please set it before running baseline: "
"set OPENAI_API_KEY=your-key-here"
)
return key
def _rule_based_agent(env: SQLDebuggerEnvironment, task: dict) -> tuple[float, int, str]:
"""
Rule-based baseline agent that analyzes the buggy query
and attempts a fix using heuristics.
Fast β€” no API calls needed for baseline scoring.
"""
context = task.get("current_context", {})
buggy_query = context.get("buggy_query", "")
error_msg = context.get("error_message", "")
error_type = context.get("error_type_hint", "syntax")
category = context.get("category", "syntax")
steps_taken = 0
total_reward = 0.0
# ── Step 1: Identify the error ────────────────────────────────
identify_payload = {
"error_location": _guess_error_location(buggy_query, error_msg, category),
"error_type": error_type,
"explanation": f"Detected {category} issue in query: {error_msg[:100]}"
}
action1 = Action(
action_type=ActionType.IDENTIFY_ERROR,
payload=identify_payload
)
resp1 = env.step(action1)
total_reward += resp1.reward.score
steps_taken += 1
if resp1.done:
return total_reward, steps_taken, resp1.reward.feedback
# ── Step 2: Submit answer based on heuristic fix ──────────────
fixed_query = _apply_heuristic_fix(buggy_query, category, error_msg)
explanation = _generate_explanation(buggy_query, fixed_query, category)
if category == "performance":
action2 = Action(
action_type=ActionType.OPTIMIZE_QUERY,
payload={
"optimized_query": fixed_query,
"optimization_type": f"Fix {category} issue: {error_type}",
"explanation": explanation,
"root_cause": f"Performance issue detected: {error_msg[:100]}",
"expected_improvement":"Significant reduction in query execution time",
"confidence": 0.6
}
)
else:
action2 = Action(
action_type=ActionType.SUBMIT_ANSWER,
payload={
"fixed_query": fixed_query,
"explanation": explanation,
"error_type": error_type,
"error_location": identify_payload["error_location"],
"confidence": 0.6
}
)
resp2 = env.step(action2)
total_reward += resp2.reward.score
steps_taken += 1
return total_reward, steps_taken, resp2.reward.feedback
def _guess_error_location(query: str, error_msg: str, category: str) -> str:
"""Heuristic: guess where the error is based on keywords."""
q = query.upper()
e = error_msg.upper()
if "SELECT" in e or "COLUMN" in e:
return "SELECT clause"
if "WHERE" in e or "FILTER" in e:
return "WHERE clause"
if "JOIN" in e or "ON" in e:
return "JOIN condition"
if "GROUP" in e or "HAVING" in e:
return "GROUP BY / HAVING clause"
if "ORDER" in e:
return "ORDER BY clause"
if category == "performance":
return "Query structure β€” performance bottleneck"
return "Unknown location"
def _apply_heuristic_fix(query: str, category: str, error_msg: str) -> str:
"""
Apply simple heuristic fixes based on category.
Not perfect β€” baseline is meant to score low-medium,
showing the environment has room for agent improvement.
"""
q = query.strip()
if category == "syntax":
# Fix missing commas in SELECT
if "syntax error" in error_msg.lower() and "name" in error_msg.lower():
import re
q = re.sub(r"SELECT\s+(\w+)\s+(\w+)", r"SELECT \1, \2", q, flags=re.IGNORECASE)
# Fix missing WHERE
if "WHERE" not in q.upper() and "=" in q:
q = q.replace(" id =", " WHERE id =")
q = q.replace(" name =", " WHERE name =")
# Fix unclosed string
if q.count("'") % 2 != 0:
q = q + "'"
# Fix ORDER β†’ ORDER BY
import re
q = re.sub(r"\bORDER\s+(?!BY)(\w)", r"ORDER BY \1", q, flags=re.IGNORECASE)
# Fix GROUP β†’ GROUP BY
q = re.sub(r"\bGROUP\s+(?!BY)(\w)", r"GROUP BY \1", q, flags=re.IGNORECASE)
elif category == "logic":
# Fix INNER JOIN β†’ LEFT JOIN for inclusion
if "INNER JOIN" in q.upper():
q = q.replace("INNER JOIN", "LEFT JOIN").replace("inner join", "LEFT JOIN")
# Fix WHERE aggregate β†’ HAVING
import re
having_pattern = re.compile(
r"WHERE\s+(AVG|SUM|COUNT|MAX|MIN)\s*\(", re.IGNORECASE
)
if having_pattern.search(q):
# Move aggregate condition to HAVING
q = having_pattern.sub("HAVING \\1(", q)
elif category == "performance":
# For performance issues, suggest JOIN-based rewrite
if "SELECT *" in q.upper():
q = q.replace("SELECT *", "SELECT id, name, status, created_at")
return q
def _generate_explanation(buggy: str, fixed: str, category: str) -> str:
"""Generate a human-readable explanation of the fix."""
if buggy.strip() == fixed.strip():
return f"Analyzed the {category} issue. The query may require deeper inspection."
explanations = {
"syntax": "Fixed syntax error in the SQL query by correcting the query structure.",
"logic": "Fixed logic error by correcting the JOIN type and query conditions.",
"performance": "Optimized query performance by restructuring to avoid expensive operations.",
}
base = explanations.get(category, "Applied heuristic fix to the SQL query.")
return f"{base} Original: '{buggy[:60]}...' Fixed: '{fixed[:60]}...'"
# ─────────────────────────────────────────────
# MAIN BASELINE RUNNER
# ─────────────────────────────────────────────
def run_baseline() -> BaselineResponse:
"""
Runs baseline agent against one task of each difficulty.
Returns BaselineResponse with scores for all 3 tasks.
Must complete within 60 seconds.
"""
try:
_check_api_key()
except ValueError as e:
print(f"Warning: {e}")
results = []
difficulties = [
(DifficultyLevel.EASY, "easy_001"),
(DifficultyLevel.MEDIUM, "medium_001"),
(DifficultyLevel.HARD, "hard_001"),
]
for difficulty, task_id in difficulties:
env = SQLDebuggerEnvironment()
try:
obs = env.reset(difficulty=difficulty.value, task_id=task_id)
task_context = {"current_context": obs.current_context}
start = time.time()
score, steps, feedback = _rule_based_agent(env, task_context)
elapsed = time.time() - start
# FIX 1: clamp score strictly between 0 and 1 exclusive
safe_score = round(max(0.001, min(0.999, float(score))), 4)
results.append(BaselineResult(
task_id = task_id,
difficulty = difficulty,
score = safe_score,
steps = steps,
feedback = f"{feedback} (elapsed: {elapsed:.2f}s)"
))
print(f"Baseline {difficulty.value}: score={safe_score}, steps={steps}")
except Exception as e:
results.append(BaselineResult(
task_id = task_id,
difficulty = difficulty,
score = 0.001, # FIX 2: was 0.0, which is an invalid boundary value
steps = 0,
feedback = f"Error: {str(e)}"
))
avg = round(sum(r.score for r in results) / len(results), 4) if results else 0.5
print(f"Baseline average score: {avg}")
return BaselineResponse(results=results, average_score=avg)
# ─────────────────────────────────────────────
# DIRECT RUN
# ─────────────────────────────────────────────
if __name__ == "__main__":
print("Running baseline agent...")
response = run_baseline()
print(f"\nFinal Results:")
for r in response.results:
print(f" {r.difficulty.value:8} | {r.task_id:12} | score={r.score} | steps={r.steps}")
print(f"\nAverage Score: {response.average_score}")