import os
import time
from env.environment import SQLDebuggerEnvironment
from env.models import (
Action, ActionType, DifficultyLevel,
BaselineResponse, BaselineResult
)
# BASELINE AGENT
# Uses rule-based heuristics — no GPT-4
# Must complete within 60 seconds
# OPENAI_API_KEY must come from environment
def _check_api_key():
    """Return the OpenAI API key taken from the process environment.

    Raises:
        ValueError: if ``OPENAI_API_KEY`` is unset or empty.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        return api_key
    # Missing and empty values are treated the same way: fail loudly.
    raise ValueError(
        "OPENAI_API_KEY environment variable is not set. "
        "Please set it before running baseline: "
        "set OPENAI_API_KEY=your-key-here"
    )
def _rule_based_agent(env: SQLDebuggerEnvironment, task: dict) -> tuple[float, int, str]:
    """Play one episode with fixed heuristics: identify the error, then submit a fix.

    The agent makes no model/API calls of its own; it only builds two actions
    from the task context and sends them through ``env.step``.

    Args:
        env: Environment to act on (the caller resets it to the task first).
        task: Mapping whose ``"current_context"`` entry holds the fields read
            here: ``buggy_query``, ``error_message``, ``error_type_hint`` and
            ``category``.

    Returns:
        ``(total_reward, steps_taken, feedback)`` where ``total_reward`` is the
        sum of ``reward.score`` over the steps taken and ``feedback`` is the
        ``reward.feedback`` of the last step.
    """
    # ``dict.get(key, default)`` only falls back when the key is missing, so a
    # context that carries an explicit ``None`` (e.g. no error message for a
    # task) would otherwise crash on the slicing below.
    context = task.get("current_context") or {}
    buggy_query = context.get("buggy_query") or ""
    error_msg = context.get("error_message") or ""
    error_type = context.get("error_type_hint", "syntax")
    category = context.get("category", "syntax")

    steps_taken = 0
    total_reward = 0.0

    # -- Step 1: identify the error ------------------------------------------
    identify_payload = {
        "error_location": _guess_error_location(buggy_query, error_msg, category),
        "error_type": error_type,
        "explanation": f"Detected {category} issue in query: {error_msg[:100]}"
    }
    action1 = Action(
        action_type=ActionType.IDENTIFY_ERROR,
        payload=identify_payload
    )
    resp1 = env.step(action1)
    total_reward += resp1.reward.score
    steps_taken += 1
    if resp1.done:
        # The environment ended the episode after the first step.
        return total_reward, steps_taken, resp1.reward.feedback

    # -- Step 2: submit an answer based on the heuristic fix -----------------
    fixed_query = _apply_heuristic_fix(buggy_query, category, error_msg)
    explanation = _generate_explanation(buggy_query, fixed_query, category)

    if category == "performance":
        # Performance tasks are answered with an optimisation action.
        action2 = Action(
            action_type=ActionType.OPTIMIZE_QUERY,
            payload={
                "optimized_query": fixed_query,
                "optimization_type": f"Fix {category} issue: {error_type}",
                "explanation": explanation,
                "root_cause": f"Performance issue detected: {error_msg[:100]}",
                "expected_improvement": "Significant reduction in query execution time",
                "confidence": 0.6
            }
        )
    else:
        action2 = Action(
            action_type=ActionType.SUBMIT_ANSWER,
            payload={
                "fixed_query": fixed_query,
                "explanation": explanation,
                "error_type": error_type,
                "error_location": identify_payload["error_location"],
                "confidence": 0.6
            }
        )

    resp2 = env.step(action2)
    total_reward += resp2.reward.score
    steps_taken += 1
    return total_reward, steps_taken, resp2.reward.feedback
def _guess_error_location(query: str, error_msg: str, category: str) -> str:
    """Guess which clause is at fault from keywords in the error message.

    Args:
        query: The buggy query. Currently unused; kept so callers do not change.
        error_msg: Error text to scan (case-insensitive). ``None`` is treated
            as an empty message.
        category: Task category; only ``"performance"`` affects the result.

    Returns:
        A short human-readable location label. The checks run in a fixed
        order and the first keyword hit wins.
    """
    import re  # function-scope import, as done elsewhere in this file

    message = (error_msg or "").upper()

    if "SELECT" in message or "COLUMN" in message:
        return "SELECT clause"
    if "WHERE" in message or "FILTER" in message:
        return "WHERE clause"
    # "ON" must be a whole word: a plain substring test also fires on words
    # such as "expression", "function" or "condition" and would shadow the
    # GROUP BY / ORDER BY checks below.
    if "JOIN" in message or re.search(r"\bON\b", message):
        return "JOIN condition"
    if "GROUP" in message or "HAVING" in message:
        return "GROUP BY / HAVING clause"
    if "ORDER" in message:
        return "ORDER BY clause"
    if category == "performance":
        return "Query structure — performance bottleneck"
    return "Unknown location"
def _apply_heuristic_fix(query: str, category: str, error_msg: str) -> str:
    """Apply simple, category-specific text rewrites to a buggy SQL query.

    The rewrites are deliberately shallow: the baseline is meant to score
    low-to-medium so that the environment leaves room for better agents.

    Args:
        query: The buggy SQL text. ``None`` is treated as an empty query.
        category: ``"syntax"``, ``"logic"`` or ``"performance"``; any other
            value returns the stripped query unchanged.
        error_msg: Error text used to gate the missing-comma rule. ``None``
            is treated as an empty message.

    Returns:
        The rewritten query (stripped of surrounding whitespace).
    """
    import re  # single function-scope import instead of one per branch

    q = (query or "").strip()
    message = (error_msg or "").lower()

    if category == "syntax":
        # Missing comma between the first two SELECT columns. The lookaheads
        # stop the rule from producing "SELECT name, FROM ..." or
        # "SELECT DISTINCT, name ..." when there is nothing to separate.
        if "syntax error" in message and "name" in message:
            q = re.sub(
                r"SELECT\s+(?!(?:DISTINCT|ALL)\b)(\w+)\s+(?!(?:FROM|AS)\b)(\w+)",
                r"SELECT \1, \2",
                q,
                flags=re.IGNORECASE,
            )

        # Missing WHERE: insert it once, before the first bare "id ="/"name ="
        # predicate. Inserting it before every predicate would yield
        # "... WHERE id = 1 AND WHERE name = ...".
        if "WHERE" not in q.upper() and "=" in q:
            q = re.sub(r" (id|name) =", r" WHERE \1 =", q, count=1)

        # Unclosed string literal: an odd number of quotes means one is missing.
        if q.count("'") % 2 != 0:
            q = q + "'"

        # ORDER -> ORDER BY and GROUP -> GROUP BY. "BY\b" keeps identifiers
        # that merely start with "by" (e.g. "bytes") from blocking the fix.
        q = re.sub(r"\bORDER\s+(?!BY\b)(\w)", r"ORDER BY \1", q, flags=re.IGNORECASE)
        q = re.sub(r"\bGROUP\s+(?!BY\b)(\w)", r"GROUP BY \1", q, flags=re.IGNORECASE)

    elif category == "logic":
        # INNER JOIN -> LEFT JOIN in any letter case (the detection was always
        # case-insensitive; the replacement now is too).
        q = re.sub(r"INNER JOIN", "LEFT JOIN", q, flags=re.IGNORECASE)

        # An aggregate used directly after WHERE belongs in HAVING. When a
        # GROUP BY follows, the condition is moved behind it, because HAVING
        # placed before GROUP BY is itself invalid SQL.
        where_agg_then_group = re.compile(
            r"WHERE\s+((?:AVG|SUM|COUNT|MAX|MIN)\s*\(.*?)\s+"
            r"(GROUP\s+BY\s+.+?)"
            r"(?=\s+ORDER\s+BY\b|\s+LIMIT\b|\s*;|\s*$)",
            re.IGNORECASE | re.DOTALL,
        )
        relocated, hits = where_agg_then_group.subn(r"\2 HAVING \1", q, count=1)
        if hits:
            q = relocated
        else:
            # No GROUP BY after the condition: rename the keyword in place.
            q = re.sub(
                r"WHERE\s+(AVG|SUM|COUNT|MAX|MIN)\s*\(",
                r"HAVING \1(",
                q,
                flags=re.IGNORECASE,
            )

    elif category == "performance":
        # Replace SELECT * with an explicit (hard-coded) column list,
        # regardless of the letter case used in the query.
        q = re.sub(
            r"SELECT \*",
            "SELECT id, name, status, created_at",
            q,
            flags=re.IGNORECASE,
        )

    return q
def _generate_explanation(buggy: str, fixed: str, category: str) -> str:
    """Describe, in one sentence plus a before/after excerpt, what was changed.

    When the fix left the query untouched (ignoring surrounding whitespace)
    a generic "needs deeper inspection" note is returned instead.
    """
    nothing_changed = buggy.strip() == fixed.strip()
    if nothing_changed:
        return f"Analyzed the {category} issue. The query may require deeper inspection."

    summaries = {
        "syntax": "Fixed syntax error in the SQL query by correcting the query structure.",
        "logic": "Fixed logic error by correcting the JOIN type and query conditions.",
        "performance": "Optimized query performance by restructuring to avoid expensive operations.",
    }
    if category in summaries:
        summary = summaries[category]
    else:
        summary = "Applied heuristic fix to the SQL query."

    # Only the first 60 characters of each query are quoted.
    before = buggy[:60]
    after = fixed[:60]
    return f"{summary} Original: '{before}...' Fixed: '{after}...'"
# ─────────────────────────────────────────────
# MAIN BASELINE RUNNER
# ─────────────────────────────────────────────
def run_baseline() -> BaselineResponse:
    """Run the rule-based agent on one fixed task per difficulty level.

    A missing API key only produces a printed warning. A failure while
    resetting or playing a task is recorded as a result with the minimum
    score instead of aborting the run.

    Returns:
        A ``BaselineResponse`` holding one ``BaselineResult`` per task and
        the mean of their scores.
    """
    try:
        _check_api_key()
    except ValueError as missing_key:
        print(f"Warning: {missing_key}")

    task_plan = (
        (DifficultyLevel.EASY, "easy_001"),
        (DifficultyLevel.MEDIUM, "medium_001"),
        (DifficultyLevel.HARD, "hard_001"),
    )

    results = []
    for difficulty, task_id in task_plan:
        # A fresh environment per task; construction happens outside the try.
        env = SQLDebuggerEnvironment()
        try:
            observation = env.reset(difficulty=difficulty.value, task_id=task_id)
            task_context = {"current_context": observation.current_context}

            started_at = time.time()
            score, steps, feedback = _rule_based_agent(env, task_context)
            elapsed = time.time() - started_at

            # Keep the reported score strictly inside (0, 1).
            capped = min(0.999, float(score))
            safe_score = round(max(0.001, capped), 4)

            outcome = BaselineResult(
                task_id=task_id,
                difficulty=difficulty,
                score=safe_score,
                steps=steps,
                feedback=f"{feedback} (elapsed: {elapsed:.2f}s)",
            )
            results.append(outcome)
            print(f"Baseline {difficulty.value}: score={safe_score}, steps={steps}")
        except Exception as failure:
            # 0.001 rather than 0.0: the lower boundary itself is not a valid score.
            outcome = BaselineResult(
                task_id=task_id,
                difficulty=difficulty,
                score=0.001,
                steps=0,
                feedback=f"Error: {failure!s}",
            )
            results.append(outcome)

    if results:
        avg = round(sum(item.score for item in results) / len(results), 4)
    else:
        avg = 0.5
    print(f"Baseline average score: {avg}")
    return BaselineResponse(results=results, average_score=avg)
# ─────────────────────────────────────────────
# DIRECT RUN
# ─────────────────────────────────────────────
if __name__ == "__main__":
    # Script entry point: run the baseline once and print a per-task table.
    print("Running baseline agent...")
    response = run_baseline()

    print("\nFinal Results:")
    for result in response.results:
        row = f" {result.difficulty.value:8} | {result.task_id:12} | score={result.score} | steps={result.steps}"
        print(row)

    print(f"\nAverage Score: {response.average_score}")