File size: 11,581 Bytes
ad1fdc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59876a8
 
 
ad1fdc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59876a8
ad1fdc5
 
 
59876a8
ad1fdc5
 
 
 
 
 
59876a8
ad1fdc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59876a8
ad1fdc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59876a8
ad1fdc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59876a8
ad1fdc5
59876a8
ad1fdc5
 
 
 
 
59876a8
ad1fdc5
 
 
 
 
 
 
 
 
 
59876a8
ad1fdc5
 
 
 
59876a8
ad1fdc5
 
 
 
 
 
 
 
 
 
 
 
59876a8
ad1fdc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59876a8
ad1fdc5
 
 
 
 
 
59876a8
ad1fdc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59876a8
ad1fdc5
 
 
59876a8
ad1fdc5
59876a8
ad1fdc5
 
 
 
 
 
 
 
 
59876a8
ad1fdc5
 
 
 
 
59876a8
ad1fdc5
 
 
59876a8
ad1fdc5
 
59876a8
ad1fdc5
 
 
 
59876a8
ad1fdc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59876a8
ad1fdc5
 
 
 
 
59876a8
ad1fdc5
 
 
 
 
59876a8
ad1fdc5
 
 
 
59876a8
ad1fdc5
 
 
 
59876a8
ad1fdc5
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320

import sqlite3
import json
from typing import Any, Dict, List, Optional, Tuple
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

try:
    from ..models import SQLArenaAction, SQLArenaObservation
    from ..tasks import SQLTask, get_task, get_tasks_by_difficulty, ALL_TASKS
except ImportError:
    from models import SQLArenaAction, SQLArenaObservation
    from tasks import SQLTask, get_task, get_tasks_by_difficulty, ALL_TASKS


# Episode limits for the SQL arena environment.
# Once this many explore queries have been used, a further explore action is
# treated as the final submission (see SQLArenaEnvironment.step).
MAX_EXPLORE_STEPS = 5        # max explore queries allowed
# Rows beyond this cap are dropped from agent-visible results and replaced by a
# single "__info__" marker row (see SQLArenaEnvironment._execute_safe).
MAX_RESULT_ROWS = 20         # max rows shown to agent
# NOTE(review): not referenced by any code visible in this file, so no query
# timeout is actually enforced here - confirm whether it is used elsewhere.
MAX_QUERY_TIMEOUT_MS = 5000  # query timeout in ms


class SQLArenaEnvironment(Environment):

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Create an environment with no task loaded and no database open.

        ``reset()`` must be called to select a task and build its in-memory
        database before ``step()`` can execute queries.
        """
        # Episode bookkeeping exposed through the ``state`` property.
        fresh_episode_id = str(uuid4())
        self._state = State(episode_id=fresh_episode_id, step_count=0)

        # Per-episode resources; filled in by reset().
        self._conn: Optional[sqlite3.Connection] = None
        self._current_task: Optional[SQLTask] = None

        # Progress flags for the current episode.
        self._explore_steps_used: int = 0
        self._submitted: bool = False

        # Reference answer rows for the current task (None until computed).
        self._expected_result: Optional[List[Dict]] = None

        # Not read by any method visible in this class.
        self._task_id_override: Optional[str] = None
        self._difficulty_override: Optional[str] = None

    # openenv required interface
    # ─────────────────────────────────────────

    def reset(self, task_id: Optional[str] = None, difficulty: Optional[str] = None) -> SQLArenaObservation:
        """Start a new episode: choose a task and build a fresh in-memory database.

        Args:
            task_id: Task to load. If ``get_task`` returns nothing for it, the
                first entry of ``ALL_TASKS`` is used instead.
            difficulty: Only consulted when ``task_id`` is falsy. A random task
                from ``get_tasks_by_difficulty`` is chosen, falling back to the
                first entry of ``ALL_TASKS`` when that list is empty.
                When both arguments are falsy a random task is chosen.

        Returns:
            An observation with ``query_type="reset"`` and an empty result.
            If loading the task's schema or seed SQL fails, ``query_error``
            carries the failure message.
        """
        import random  # function-local, as in the original implementation

        # close old db connection (best-effort: a failing close must not
        # prevent a new episode from starting)
        if self._conn:
            try:
                self._conn.close()
            except Exception:
                pass
            finally:
                # never keep a handle to a closed connection around
                self._conn = None

        # pick which task to run
        if task_id:
            task = get_task(task_id) or list(ALL_TASKS.values())[0]
        elif difficulty:
            candidates = get_tasks_by_difficulty(difficulty)
            task = random.choice(candidates) if candidates else list(ALL_TASKS.values())[0]
        else:
            task = random.choice(list(ALL_TASKS.values()))

        self._current_task = task
        self._explore_steps_used = 0
        self._submitted = False
        # Forget the previous episode's reference answer. Without this, a
        # failed setup below would leave the old answer in place and the
        # observation / grader would use data from a different task.
        self._expected_result = None
        self._state = State(episode_id=str(uuid4()), step_count=0)

        # fresh sqlite db in memory
        self._conn = sqlite3.connect(":memory:", check_same_thread=False)
        self._conn.row_factory = sqlite3.Row

        # Load schema and seed data
        try:
            self._conn.executescript(task.schema_sql)
            self._conn.executescript(task.seed_sql)
            self._conn.commit()
        except Exception as e:
            # Boundary handler: any setup failure is reported to the caller
            # as an error observation instead of propagating.
            return self._make_obs(
                query_error=f"Environment setup error: {e}",
                query_result=[],
                query_type="reset",
            )

        # run reference answer ahead of time (None if the reference query fails)
        self._expected_result = self._run_query_safe(task.solution_sql)

        return self._make_obs(
            query_result=[],
            query_type="reset",
            query_error=None,
        )

    def step(self, action: SQLArenaAction) -> SQLArenaObservation:
        """Execute one agent action: an exploratory query or the final submission.

        An action whose lowered ``query_type`` is ``"explore"`` runs the SQL for
        a small negative reward while the explore budget lasts. Any other
        query type, or an explore action after the budget is exhausted, is
        treated as the submission: it ends the episode and is graded against
        the reference result.

        Args:
            action: Carries the SQL text (``sql``) and the ``query_type``.

        Returns:
            The observation for this step, including reward and done flag.
        """
        if self._current_task is None:
            return self._make_obs(
                query_result=[],
                query_error="Environment not initialized. Call reset() first.",
                query_type=action.query_type,
                done=True,
                reward=-1.0,
            )

        if self._submitted:
            return self._make_obs(
                query_result=[],
                query_error="Episode already ended. Call reset() to start a new episode.",
                query_type=action.query_type,
                done=True,
                reward=0.0,
            )

        self._state.step_count += 1
        statement = action.sql.strip()
        wants_explore = action.query_type.lower() == "explore"

        # explore step (only while budget remains; otherwise fall through
        # and force a submit)
        if wants_explore and self._explore_steps_used < MAX_EXPLORE_STEPS:
            self._explore_steps_used += 1
            rows, failure = self._execute_safe(statement)
            return self._make_obs(
                query_result=rows,
                query_error=failure,
                query_type="explore",
                done=False,
                reward=-0.02,  # small cost per explore step
            )

        # submit step
        self._submitted = True
        rows, failure = self._execute_safe(statement)

        if failure:
            # syntax error penalty
            return self._make_obs(
                query_result=[],
                query_error=failure,
                query_type="submit",
                done=True,
                reward=-0.1,
                is_correct=False,
                feedback=f"SQL error on submission: {failure}. Correct your query.",
            )

        # grade the submitted query
        is_correct, partial, feedback = self._grade(rows)
        reward = 1.0 if is_correct else (0.4 if partial else 0.0)

        return self._make_obs(
            query_result=rows,
            query_error=None,
            query_type="submit",
            done=True,
            reward=reward,
            is_correct=is_correct,
            feedback=feedback,
        )

    @property
    def state(self) -> State:
        return self._state

    # internal helpers
    # ─────────────────────────────────────────

    def _execute_safe(self, sql: str) -> Tuple[List[Dict[str, Any]], Optional[str]]:
        if not self._conn:
            return [], "Database not initialized"
        try:
            # block dangerous ops
            sql_upper = sql.upper().strip()
            dangerous = ["DROP ", "ALTER ", "TRUNCATE ", "PRAGMA ", "ATTACH ", "DETACH "]
            if any(sql_upper.startswith(d) for d in dangerous):
                return [], "Operation not permitted in this environment"

            cursor = self._conn.execute(sql)
            rows = cursor.fetchmany(MAX_RESULT_ROWS + 1)
            truncated = len(rows) > MAX_RESULT_ROWS
            rows = rows[:MAX_RESULT_ROWS]
            result = [dict(row) for row in rows]
            if truncated:
                result.append({"__info__": f"Results truncated to {MAX_RESULT_ROWS} rows"})
            return result, None
        except sqlite3.Error as e:
            return [], str(e)
        except Exception as e:
            return [], f"Unexpected error: {e}"

    def _run_query_safe(self, sql: str) -> Optional[List[Dict[str, Any]]]:
        if not self._conn:
            return None
        try:
            cursor = self._conn.execute(sql)
            rows = cursor.fetchall()
            return [dict(row) for row in rows]
        except Exception:
            return None

    def _grade(self, agent_result: List[Dict]) -> Tuple[bool, bool, str]:
        expected = self._expected_result
        if expected is None:
            return False, False, "Could not compute expected result. Contact organizers."

        # compare both result sets
        def normalize(rows: List[Dict]) -> List[str]:
            normalized = []
            for row in rows:
                # skip internal marker rows
                clean = {k: v for k, v in row.items() if not k.startswith("__")}
                # round floats to avoid precision noise
                rounded = {}
                for k, v in clean.items():
                    if isinstance(v, float):
                        rounded[k] = round(v, 2)
                    else:
                        rounded[k] = v
                normalized.append(json.dumps(rounded, sort_keys=True))
            return sorted(normalized)

        # remove info markers from agent result
        agent_clean = [r for r in agent_result if not any(k.startswith("__") for k in r.keys())]

        expected_norm = normalize(expected)
        agent_norm = normalize(agent_clean)

        # exact match
        if expected_norm == agent_norm:
            return True, False, f"βœ“ Correct! Your query returned the exact expected result ({len(expected)} rows)."

        # check partial match conditions
        same_count = len(agent_clean) == len(expected)

        # same column names?
        exp_cols = set(expected[0].keys()) if expected else set()
        agent_cols = set(agent_clean[0].keys()) if agent_clean else set()
        same_cols = exp_cols == agent_cols

        # count matching rows
        matching_rows = len(set(expected_norm) & set(agent_norm))
        match_pct = matching_rows / max(len(expected_norm), 1) * 100

        if same_cols and same_count and match_pct >= 50:
            return False, True, (
                f"Partial credit. Correct columns, {same_count} rows, "
                f"but {matching_rows}/{len(expected)} rows match exactly. "
                f"Check your WHERE conditions or aggregation."
            )

        if same_cols and not same_count:
            return False, True, (
                f"Partial credit. Correct columns but wrong row count: "
                f"got {len(agent_clean)}, expected {len(expected)}. "
                f"Check your filters."
            ) if match_pct >= 30 else (False, False, (
                f"Wrong answer. Got {len(agent_clean)} rows, expected {len(expected)}. "
                f"Expected columns: {sorted(exp_cols)}."
            ))

        feedback = (
            f"Wrong answer. Expected {len(expected)} rows with columns {sorted(exp_cols)}. "
            f"You returned {len(agent_clean)} rows"
            + (f" with columns {sorted(agent_cols)}." if agent_clean else " (empty result).")
        )
        return False, False, feedback

    def _make_obs(
        self,
        query_result: List[Dict],
        query_type: str,
        query_error: Optional[str] = None,
        done: bool = False,
        reward: float = 0.0,
        is_correct: Optional[bool] = None,
        feedback: Optional[str] = None,
    ) -> SQLArenaObservation:
        """Assemble the observation returned to the agent for the current step.

        Args:
            query_result: Rows to report (may include ``"__"``-prefixed marker rows).
            query_type: Label for the action that produced this observation.
            query_error: Error message to report, if any.
            done: Whether the episode has ended.
            reward: Reward for this step.
            is_correct: Grading verdict, when the step was graded.
            feedback: Human-readable grading feedback, when available.

        Returns:
            An observation combining task info, the query outcome, episode
            progress and the base done/reward/metadata fields.
        """
        task = self._current_task
        expected = self._expected_result
        # Test against None, not truthiness: a reference answer with zero rows
        # must report 0, not "unknown".
        expected_count = len(expected) if expected is not None else None
        # Marker rows (any key starting with "__") do not count as data rows.
        data_rows = [r for r in query_result if not any(k.startswith("__") for k in r)]

        return SQLArenaObservation(
            # task info
            task_id=task.task_id if task else "",
            difficulty=task.difficulty if task else "easy",
            question=task.question if task else "",
            schema_info=task.schema_description if task else "",

            # query result
            query_result=query_result,
            query_error=query_error,
            query_type=query_type,
            rows_returned=len(data_rows),

            # episode progress
            explore_steps_used=self._explore_steps_used,
            explore_steps_remaining=max(0, MAX_EXPLORE_STEPS - self._explore_steps_used),
            submitted=self._submitted,

            # feedback fields
            is_correct=is_correct,
            feedback=feedback,
            expected_row_count=expected_count,

            # base openenv fields
            done=done,
            reward=reward,
            metadata={
                "episode_id": self._state.episode_id,
                "step_count": self._state.step_count,
            },
        )