File size: 9,323 Bytes
3c1b0c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea504bf
 
 
3c1b0c7
 
 
ea504bf
3c1b0c7
 
 
ea504bf
3c1b0c7
 
 
 
 
ea504bf
3c1b0c7
 
 
 
ea504bf
3c1b0c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import os
import time
from env.environment import SQLDebuggerEnvironment
from env.models import (
    Action, ActionType, DifficultyLevel,
    BaselineResponse, BaselineResult
)


#  BASELINE AGENT
#  Uses rule-based heuristics — no GPT-4
#  Must complete within 60 seconds
#  OPENAI_API_KEY must come from environment

def _check_api_key():
    """Return the value of the OPENAI_API_KEY environment variable.

    Raises:
        ValueError: If the variable is unset or set to an empty string.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        return api_key
    raise ValueError(
        "OPENAI_API_KEY environment variable is not set. "
        "Please set it before running baseline: "
        "set OPENAI_API_KEY=your-key-here"
    )


def _rule_based_agent(env: SQLDebuggerEnvironment, task: dict) -> tuple[float, int, str]:
    """Play one episode with fixed heuristics: identify the error, then submit a fix.

    The agent takes at most two environment steps and makes no API calls.

    Args:
        env: Environment to step; each ``env.step(action)`` result is read for
            ``done``, ``reward.score`` and ``reward.feedback``.
        task: Mapping whose ``"current_context"`` entry is a dict that may hold
            ``buggy_query``, ``error_message``, ``error_type_hint`` and
            ``category`` (missing keys fall back to ``""`` / ``"syntax"``).

    Returns:
        ``(total_reward, steps_taken, feedback)`` where ``feedback`` comes from
        the last step that was executed.
    """
    ctx = task.get("current_context", {})
    buggy_query = ctx.get("buggy_query", "")
    error_msg = ctx.get("error_message", "")
    error_type = ctx.get("error_type_hint", "syntax")
    category = ctx.get("category", "syntax")

    # Step 1: report where the error seems to be.
    identify_payload = {
        "error_location": _guess_error_location(buggy_query, error_msg, category),
        "error_type":     error_type,
        "explanation":    f"Detected {category} issue in query: {error_msg[:100]}"
    }
    identify_response = env.step(
        Action(action_type=ActionType.IDENTIFY_ERROR, payload=identify_payload)
    )
    total_reward = 0.0 + identify_response.reward.score

    # The environment may end the episode right after the first step.
    if identify_response.done:
        return total_reward, 1, identify_response.reward.feedback

    # Step 2: submit the heuristic rewrite. Performance tasks use the
    # optimize action, everything else the plain answer submission.
    fixed_query = _apply_heuristic_fix(buggy_query, category, error_msg)
    explanation = _generate_explanation(buggy_query, fixed_query, category)

    if category == "performance":
        final_type = ActionType.OPTIMIZE_QUERY
        final_payload = {
            "optimized_query":     fixed_query,
            "optimization_type":   f"Fix {category} issue: {error_type}",
            "explanation":         explanation,
            "root_cause":          f"Performance issue detected: {error_msg[:100]}",
            "expected_improvement":"Significant reduction in query execution time",
            "confidence":          0.6
        }
    else:
        final_type = ActionType.SUBMIT_ANSWER
        final_payload = {
            "fixed_query":   fixed_query,
            "explanation":   explanation,
            "error_type":    error_type,
            "error_location": identify_payload["error_location"],
            "confidence":    0.6
        }

    final_response = env.step(Action(action_type=final_type, payload=final_payload))
    total_reward += final_response.reward.score

    return total_reward, 2, final_response.reward.feedback


def _guess_error_location(query: str, error_msg: str, category: str) -> str:
    """Guess which part of the query is at fault from keywords in the error text.

    Rules are tried in a fixed order and the first hit wins: SELECT/COLUMN,
    WHERE/FILTER, JOIN/ON, GROUP/HAVING, ORDER. If none match, performance
    tasks get a generic "query structure" answer and everything else is
    reported as unknown.

    Args:
        query: The buggy SQL text. Currently not consulted; kept so the
            signature stays stable for callers.
        error_msg: Error message associated with the query.
        category: Task category; only ``"performance"`` is treated specially.

    Returns:
        A short human-readable location label.
    """
    import re  # function-scope import: ``re`` is not imported at file level

    message = error_msg.upper()

    if "SELECT" in message or "COLUMN" in message:
        return "SELECT clause"
    if "WHERE" in message or "FILTER" in message:
        return "WHERE clause"
    # "ON" must match as a whole word. A plain substring test also fires on
    # FUNCTION, CONNECTION, UNION, ... and would hide the rules below.
    if "JOIN" in message or re.search(r"\bON\b", message):
        return "JOIN condition"
    if "GROUP" in message or "HAVING" in message:
        return "GROUP BY / HAVING clause"
    if "ORDER" in message:
        return "ORDER BY clause"
    if category == "performance":
        return "Query structure — performance bottleneck"
    return "Unknown location"


def _apply_heuristic_fix(query: str, category: str, error_msg: str) -> str:
    """Apply simple text-level repairs to a buggy SQL query.

    The repairs are deliberately shallow: the baseline is meant to score
    low-to-medium so the environment leaves room for better agents.

    Args:
        query: The buggy SQL text; surrounding whitespace is stripped first.
        category: ``"syntax"``, ``"logic"`` or ``"performance"``. Any other
            value returns the stripped query unchanged.
        error_msg: Error message for the query; only consulted for the
            missing-comma repair in the syntax branch.

    Returns:
        The query after whichever repairs applied.
    """
    import re  # function-scope import: ``re`` is not imported at file level

    fixed = query.strip()

    if category == "syntax":
        # Missing comma between the first two SELECT columns. The lookaheads
        # stop the rewrite from producing "SELECT DISTINCT, x" or
        # "SELECT x, FROM ..." / "SELECT x, AS y" when no comma is missing.
        lowered_error = error_msg.lower()
        if "syntax error" in lowered_error and "name" in lowered_error:
            fixed = re.sub(
                r"SELECT\s+(?!DISTINCT\b)(\w+)\s+(?!(?:FROM|AS)\b)(\w+)",
                r"SELECT \1, \2",
                fixed,
                flags=re.IGNORECASE,
            )

        # Missing WHERE: insert the keyword once, in front of whichever of
        # " id =" / " name =" comes first, so a query containing both does
        # not end up with two WHERE keywords.
        if "WHERE" not in fixed.upper() and "=" in fixed:
            hits = [
                pos
                for pos in (fixed.find(" id ="), fixed.find(" name ="))
                if pos != -1
            ]
            if hits:
                cut = min(hits)
                fixed = f"{fixed[:cut]} WHERE{fixed[cut:]}"

        # Unclosed string literal: an odd quote count gets a closing quote.
        if fixed.count("'") % 2 != 0:
            fixed = fixed + "'"

        # Bare ORDER / GROUP become ORDER BY / GROUP BY. ``BY\b`` keeps
        # columns whose names merely start with "by" eligible for the fix.
        fixed = re.sub(
            r"\bORDER\s+(?!BY\b)(\w)", r"ORDER BY \1", fixed, flags=re.IGNORECASE
        )
        fixed = re.sub(
            r"\bGROUP\s+(?!BY\b)(\w)", r"GROUP BY \1", fixed, flags=re.IGNORECASE
        )

    elif category == "logic":
        # INNER JOIN -> LEFT JOIN so unmatched rows are kept. Matching is
        # case-insensitive so "Inner Join" is rewritten as well.
        fixed = re.sub(r"INNER JOIN", "LEFT JOIN", fixed, flags=re.IGNORECASE)

        # An aggregate filtered in WHERE: swap the keyword for HAVING.
        # NOTE(review): only the keyword is replaced; the clause is not moved
        # after GROUP BY, so the result may still be invalid SQL.
        fixed = re.sub(
            r"WHERE\s+(AVG|SUM|COUNT|MAX|MIN)\s*\(",
            "HAVING \\1(",
            fixed,
            flags=re.IGNORECASE,
        )

    elif category == "performance":
        # Replace SELECT * with an explicit column list, in any letter case.
        # NOTE(review): the column names are hard-coded and presumably match
        # the task schema; confirm against the task definitions.
        fixed = re.sub(
            r"SELECT \*",
            "SELECT id, name, status, created_at",
            fixed,
            flags=re.IGNORECASE,
        )

    return fixed


def _generate_explanation(buggy: str, fixed: str, category: str) -> str:
    """Describe the applied fix in one human-readable sentence.

    When the stripped queries are identical (no heuristic changed anything)
    a generic "needs deeper inspection" note is returned. Otherwise the
    result is a category-specific summary followed by the first 60
    characters of the original and of the fixed query.
    """
    if buggy.strip() != fixed.strip():
        if category == "syntax":
            summary = "Fixed syntax error in the SQL query by correcting the query structure."
        elif category == "logic":
            summary = "Fixed logic error by correcting the JOIN type and query conditions."
        elif category == "performance":
            summary = "Optimized query performance by restructuring to avoid expensive operations."
        else:
            summary = "Applied heuristic fix to the SQL query."
        before = buggy[:60]
        after = fixed[:60]
        return f"{summary} Original: '{before}...' Fixed: '{after}...'"

    return f"Analyzed the {category} issue. The query may require deeper inspection."


# ─────────────────────────────────────────────
#  MAIN BASELINE RUNNER
# ─────────────────────────────────────────────

def run_baseline() -> BaselineResponse:
    """Run the rule-based agent on one fixed task per difficulty level.

    A missing OPENAI_API_KEY only produces a printed warning; the rule-based
    agent runs regardless. Each task gets a fresh environment, and a failure
    in one task is recorded as a minimum-score result instead of aborting
    the run. The run is intended to finish within 60 seconds; no timeout is
    enforced here.

    Returns:
        A ``BaselineResponse`` holding one ``BaselineResult`` per task and
        the average of their scores.
    """
    try:
        _check_api_key()
    except ValueError as exc:
        print(f"Warning: {exc}")

    task_plan = [
        (DifficultyLevel.EASY,   "easy_001"),
        (DifficultyLevel.MEDIUM, "medium_001"),
        (DifficultyLevel.HARD,   "hard_001"),
    ]
    results = []

    for difficulty, task_id in task_plan:
        env = SQLDebuggerEnvironment()
        try:
            obs = env.reset(difficulty=difficulty.value, task_id=task_id)
            task_context = {"current_context": obs.current_context}

            # perf_counter is monotonic, so the elapsed time cannot be
            # distorted by system clock adjustments the way time.time() can.
            started = time.perf_counter()
            score, steps, feedback = _rule_based_agent(env, task_context)
            elapsed = time.perf_counter() - started

            # Scores must lie strictly inside (0, 1): clamp away from both
            # boundary values before rounding.
            safe_score = round(max(0.001, min(0.999, float(score))), 4)

            results.append(BaselineResult(
                task_id=task_id,
                difficulty=difficulty,
                score=safe_score,
                steps=steps,
                feedback=f"{feedback} (elapsed: {elapsed:.2f}s)",
            ))
            print(f"Baseline {difficulty.value}: score={safe_score}, steps={steps}")

        except Exception as exc:
            # Top-level boundary: keep the run going and report the failure
            # in the result. 0.001 is used because 0.0 is not a valid score.
            results.append(BaselineResult(
                task_id=task_id,
                difficulty=difficulty,
                score=0.001,
                steps=0,
                feedback=f"Error: {exc}",
            ))

    avg = round(sum(r.score for r in results) / len(results), 4) if results else 0.5
    print(f"Baseline average score: {avg}")

    return BaselineResponse(results=results, average_score=avg)


# ─────────────────────────────────────────────
#  DIRECT RUN
# ─────────────────────────────────────────────

if __name__ == "__main__":
    # Script entry point: run the baseline once and print a per-task table
    # followed by the average score.
    print("Running baseline agent...")
    response = run_baseline()
    print("\nFinal Results:")
    for r in response.results:
        row = f"  {r.difficulty.value:8} | {r.task_id:12} | score={r.score} | steps={r.steps}"
        print(row)
    print(f"\nAverage Score: {response.average_score}")