Spaces:
Runtime error
Runtime error
| import sqlite3 | |
| import random | |
| from typing import Any, Optional, Tuple | |
| from openenv.core.env_server.interfaces import Environment | |
| from models import SQLAction, SQLObservation, SQLState | |
| from server.challenges import CHALLENGES | |
| def _run_query(schema_sql: str, query: str) -> Tuple[bool, str]: | |
| """ | |
| Execute query against an in-memory SQLite DB seeded with schema_sql. | |
| Returns (success: bool, result_string: str). | |
| """ | |
| try: | |
| conn = sqlite3.connect(":memory:") | |
| conn.executescript(schema_sql) | |
| cursor = conn.execute(query) | |
| rows = cursor.fetchall() | |
| col_names = [desc[0] for desc in cursor.description] if cursor.description else [] | |
| conn.close() | |
| if not rows: | |
| return True, "(no rows returned)" | |
| # Format as a simple text table | |
| header = " | ".join(col_names) | |
| sep = "-" * len(header) | |
| row_lines = [" | ".join(str(v) for v in row) for row in rows] | |
| return True, "\n".join([header, sep] + row_lines) | |
| except Exception as e: | |
| return False, f"ERROR: {e}" | |
| def _results_match(schema_sql: str, query_a: str, query_b: str) -> bool: | |
| """Check whether two queries return identical result sets.""" | |
| try: | |
| conn = sqlite3.connect(":memory:") | |
| conn.executescript(schema_sql) | |
| rows_a = set(conn.execute(query_a).fetchall()) | |
| rows_b = set(conn.execute(query_b).fetchall()) | |
| conn.close() | |
| return rows_a == rows_b | |
| except Exception: | |
| return False | |
class SQLTutorEnvironment(Environment[SQLAction, SQLObservation, SQLState]):
    """
    Environment in which an agent repairs a broken SQL query.

    ``reset`` picks one challenge from ``CHALLENGES`` and shows the broken
    query, the schema/task descriptions and the broken query's current output.
    Each ``step`` is either ``"request_hint"`` or ``"submit_fix"``; a fix is
    correct when its result rows match those of the challenge's
    ``correct_query``. The episode ends on a correct fix or once ``max_steps``
    steps have been taken.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self):
        super().__init__()
        self._state = SQLState()
        # Per-instance RNG: reseeding the module-level ``random`` generator would
        # affect other sessions (SUPPORTS_CONCURRENT_SESSIONS is True) and any
        # other code sharing the global generator.
        self._rng = random.Random()

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> SQLObservation:
        """
        Start a new episode with a randomly chosen challenge.

        Args:
            seed: if given, reseeds this environment's RNG so the challenge
                choice is reproducible.
            episode_id: stored on the new state as-is.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            An observation showing the broken query's output, with
            ``done=False`` and ``reward=None``.
        """
        if seed is not None:
            # Same draw sequence as ``random.seed(seed)`` + ``random.choice``.
            self._rng.seed(seed)
        challenge = self._rng.choice(CHALLENGES)
        self._state = SQLState(
            challenge_id=challenge["id"],
            broken_query=challenge["broken_query"],
            correct_query=challenge["correct_query"],
            schema_sql=challenge["schema_sql"],
            schema_description=challenge["schema_description"],
            task_description=challenge["task_description"],
            hints=challenge["hints"],
            steps_taken=0,
            max_steps=5,
            hints_used=0,
            is_resolved=False,
            cumulative_reward=0.0,
            episode_id=episode_id,
            step_count=0,
        )
        state = self._state
        # Show the agent the broken query output so it understands what's wrong.
        _, broken_result = _run_query(state.schema_sql, state.broken_query)
        return self._build_observation(
            f"Current (broken) query output:\n{broken_result}",
            is_correct=False,
            hint=None,
            done=False,
            reward=None,
        )

    def step(
        self,
        action: SQLAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> SQLObservation:
        """
        Apply one agent action and return the resulting observation.

        Args:
            action: ``action.action_type`` is ``"request_hint"`` or
                ``"submit_fix"``; ``action.sql_query`` is used by ``"submit_fix"``.
            timeout_s: accepted for interface compatibility; currently unused.
            **kwargs: accepted for interface compatibility; unused.

        Rewards:
            * request_hint: -0.1. The next hint is revealed, and the last hint
              is repeated once all have been shown.
            * submit_fix with no query: -0.05; query that errors: -0.1;
              query that runs but returns different rows: -0.05.
            * correct fix:
              ``max(0.1, 1.0 - 0.15 * hints_used - 0.05 * (steps_taken - 1))``.
            * unknown action_type: -0.05.
            * any step after the episode has ended: 0.0, and state is unchanged.
        """
        state = self._state

        # Once an episode is over, further steps must not accrue reward
        # (previously re-submitting the correct query kept paying out).
        if state.is_resolved or state.steps_taken >= state.max_steps:
            return self._build_observation(
                "Episode is already finished. Call reset() to start a new episode.",
                is_correct=bool(state.is_resolved),
                hint=None,
                done=True,
                reward=0.0,
            )

        state.steps_taken += 1
        state.step_count += 1

        hint = None
        done = False
        is_correct = False

        if action.action_type == "request_hint":
            if state.hints:
                # Clamp so repeated requests keep returning the last hint.
                hint = state.hints[min(state.hints_used, len(state.hints) - 1)]
            state.hints_used += 1
            reward = -0.1  # small penalty for using a hint
            _, broken_output = _run_query(state.schema_sql, state.broken_query)
            execution_result = f"(Hint requested - no query executed)\nBroken query output:\n{broken_output}"
        elif action.action_type == "submit_fix":
            if not action.sql_query:
                execution_result = "ERROR: You chose 'submit_fix' but provided no sql_query."
                reward = -0.05
            else:
                success, execution_result = _run_query(state.schema_sql, action.sql_query)
                if not success:
                    reward = -0.1
                elif _results_match(state.schema_sql, action.sql_query, state.correct_query):
                    is_correct = True
                    # Reward decreases with hints used and steps taken.
                    hint_penalty = 0.15 * state.hints_used
                    step_penalty = 0.05 * max(0, state.steps_taken - 1)
                    reward = max(0.1, 1.0 - hint_penalty - step_penalty)
                    state.is_resolved = True
                    done = True
                else:
                    reward = -0.05
        else:
            execution_result = f"ERROR: Unknown action_type '{action.action_type}'. Use 'submit_fix' or 'request_hint'."
            reward = -0.05

        # End the episode once the step budget is exhausted.
        if state.steps_taken >= state.max_steps:
            done = True

        state.cumulative_reward += reward
        return self._build_observation(
            execution_result,
            is_correct=is_correct,
            hint=hint,
            done=done,
            reward=reward,
        )

    def _build_observation(
        self,
        execution_result: str,
        *,
        is_correct: bool,
        hint: Optional[str],
        done: bool,
        reward: Optional[float],
    ) -> SQLObservation:
        """Build an observation from the current state plus per-step fields."""
        state = self._state
        return SQLObservation(
            broken_query=state.broken_query,
            schema_description=state.schema_description,
            task_description=state.task_description,
            execution_result=execution_result,
            is_correct=is_correct,
            hint=hint,
            steps_taken=state.steps_taken,
            max_steps=state.max_steps,
            hints_used=state.hints_used,
            done=done,
            reward=reward,
        )

    @property
    def state(self) -> SQLState:
        """Current episode state.

        NOTE(review): exposed as a property to match OpenEnv's ``Environment``
        interface (``env.state``); confirm against the installed openenv version.
        """
        return self._state