import random
import sqlite3
import time
from collections import Counter
from contextlib import closing
from typing import Any, List, Optional, Tuple

from openenv.core.env_server.interfaces import Environment

from models import SQLAction, SQLObservation, SQLState
from server.challenges import CHALLENGES

# Wall-clock budget for one query when the caller supplies none. Agent SQL is
# untrusted: a runaway recursive CTE or a huge cross join would otherwise block
# the environment server indefinitely.
_DEFAULT_QUERY_TIMEOUT_S = 5.0

# Number of SQLite VM instructions executed between deadline checks.
_PROGRESS_CHECK_INTERVAL = 1000


def _execute_isolated(
    schema_sql: str,
    query: str,
    timeout_s: Optional[float] = None,
) -> Tuple[List[str], List[tuple]]:
    """Run *query* against a fresh in-memory DB seeded with *schema_sql*.

    Every call gets its own database, so a statement that modifies data
    (``DELETE``, ``UPDATE``, ``DROP`` ...) can never influence another query.
    The connection is always closed, even when the schema or query fails.

    Returns:
        ``(column_names, rows)``; ``column_names`` is empty when the statement
        produces no result columns.

    Raises:
        TimeoutError: if *query* runs longer than *timeout_s* seconds
            (default ``_DEFAULT_QUERY_TIMEOUT_S``).
        Exception: whatever :mod:`sqlite3` raises for invalid schema or SQL.
    """
    if timeout_s is None:
        timeout_s = _DEFAULT_QUERY_TIMEOUT_S
    # sqlite3.Connection's own context manager only commits/rolls back; it does
    # not close, hence contextlib.closing.
    with closing(sqlite3.connect(":memory:")) as conn:
        conn.executescript(schema_sql)
        deadline = time.monotonic() + timeout_s
        # A non-zero return value aborts the running statement, which bounds
        # both execution and row fetching.
        conn.set_progress_handler(
            lambda: int(time.monotonic() > deadline), _PROGRESS_CHECK_INTERVAL
        )
        try:
            cursor = conn.execute(query)
            rows = cursor.fetchall()
        except sqlite3.OperationalError as err:
            if time.monotonic() > deadline:
                raise TimeoutError(
                    f"query exceeded the {timeout_s:g}s time limit"
                ) from err
            raise
        col_names = (
            [desc[0] for desc in cursor.description] if cursor.description else []
        )
    return col_names, rows


def _run_query(
    schema_sql: str,
    query: str,
    timeout_s: Optional[float] = None,
) -> Tuple[bool, str]:
    """
    Execute query against an in-memory SQLite DB seeded with schema_sql.

    Returns (success: bool, result_string: str). On success the string is a
    plain-text table (header, dashed separator, one line per row) or
    "(no rows returned)"; on failure it is "ERROR: <message>".
    """
    try:
        col_names, rows = _execute_isolated(schema_sql, query, timeout_s)
    except Exception as e:  # untrusted SQL: report every failure to the agent
        return False, f"ERROR: {e}"

    if not rows:
        return True, "(no rows returned)"

    # Format as a simple text table
    header = " | ".join(col_names)
    sep = "-" * len(header)
    row_lines = [" | ".join(str(v) for v in row) for row in rows]
    return True, "\n".join([header, sep, *row_lines])


def _results_match(
    schema_sql: str,
    query_a: str,
    query_b: str,
    timeout_s: Optional[float] = None,
) -> bool:
    """Check whether two queries return the same rows, ignoring row order.

    Rows are compared as multisets, so duplicate rows count (a missing or
    extra ``DISTINCT`` is detected). Each query runs against its own freshly
    seeded database so that a data-modifying query cannot change what the
    other one sees. Any error in either query yields ``False``.

    NOTE(review): row order is ignored, so a challenge whose fix is only an
    ORDER BY change cannot be distinguished here.
    """
    try:
        _, rows_a = _execute_isolated(schema_sql, query_a, timeout_s)
        _, rows_b = _execute_isolated(schema_sql, query_b, timeout_s)
    except Exception:
        return False
    return Counter(rows_a) == Counter(rows_b)


class SQLTutorEnvironment(Environment[SQLAction, SQLObservation, SQLState]):
    """Environment in which an agent repairs a broken SQL query.

    ``reset`` draws one challenge from ``CHALLENGES``. Each ``step`` either
    requests a hint or submits a candidate fix, which is graded by comparing
    its result rows with the challenge's reference query. An episode ends on
    a correct fix or after ``max_steps`` steps.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self):
        super().__init__()
        self._state = SQLState()

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> SQLObservation:
        """Start a new episode on a randomly chosen challenge.

        Args:
            seed: if given, the challenge choice is reproducible for that seed.
            episode_id: stored on the new state.

        Returns:
            The initial observation, including the broken query's output.
        """
        # Use a private generator when seeded: reseeding the module-global RNG
        # would disturb every other session sharing this process. For a given
        # seed this picks the same challenge as random.seed(seed) did.
        rng = random.Random(seed) if seed is not None else random
        challenge = rng.choice(CHALLENGES)

        state = SQLState(
            challenge_id=challenge["id"],
            broken_query=challenge["broken_query"],
            correct_query=challenge["correct_query"],
            schema_sql=challenge["schema_sql"],
            schema_description=challenge["schema_description"],
            task_description=challenge["task_description"],
            hints=challenge["hints"],
            steps_taken=0,
            max_steps=5,
            hints_used=0,
            is_resolved=False,
            cumulative_reward=0.0,
            episode_id=episode_id,
            step_count=0,
        )
        self._state = state

        # Show the agent the broken query output so it understands what's wrong
        _, broken_result = _run_query(state.schema_sql, state.broken_query)

        return SQLObservation(
            broken_query=state.broken_query,
            schema_description=state.schema_description,
            task_description=state.task_description,
            execution_result=f"Current (broken) query output:\n{broken_result}",
            is_correct=False,
            hint=None,
            steps_taken=0,
            max_steps=state.max_steps,
            hints_used=0,
            done=False,
            reward=None,
        )

    def step(
        self,
        action: SQLAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> SQLObservation:
        """Apply one agent action and return the resulting observation.

        Action types:
            ``request_hint``: reveals the next hint (the last one repeats once
                all have been shown); reward -0.1.
            ``submit_fix``: runs ``action.sql_query``. A missing or blank
                query gives -0.05, a failing query -0.1, wrong rows -0.05. A
                correct result ends the episode with
                ``max(0.1, 1.0 - 0.15 * hints_used - 0.05 * (steps_taken - 1))``.
            anything else: -0.05.

        ``timeout_s`` bounds each individual SQL execution (default
        ``_DEFAULT_QUERY_TIMEOUT_S``). The episode also ends once
        ``steps_taken`` reaches ``max_steps``.
        """
        state = self._state
        state.steps_taken += 1
        state.step_count += 1

        reward = 0.0
        done = False
        hint = None
        is_correct = False

        if action.action_type == "request_hint":
            if state.hints:
                # After every hint has been shown, keep repeating the last one.
                hint = state.hints[min(state.hints_used, len(state.hints) - 1)]
            else:
                hint = "No hints are available for this challenge."
            state.hints_used += 1
            reward = -0.1  # small penalty for using a hint
            _, broken_result = _run_query(
                state.schema_sql, state.broken_query, timeout_s
            )
            execution_result = (
                f"(Hint requested - no query executed)\n"
                f"Broken query output:\n{broken_result}"
            )

        elif action.action_type == "submit_fix":
            submitted = action.sql_query
            if not (submitted or "").strip():
                execution_result = (
                    "ERROR: You chose 'submit_fix' but provided no sql_query."
                )
                reward = -0.05
            else:
                success, execution_result = _run_query(
                    state.schema_sql, submitted, timeout_s
                )
                if not success:
                    reward = -0.1
                elif _results_match(
                    state.schema_sql, submitted, state.correct_query, timeout_s
                ):
                    is_correct = True
                    # Reward decreases with hints used and steps taken
                    hint_penalty = 0.15 * state.hints_used
                    step_penalty = 0.05 * max(0, state.steps_taken - 1)
                    reward = max(0.1, 1.0 - hint_penalty - step_penalty)
                    state.is_resolved = True
                    done = True
                else:
                    reward = -0.05

        else:
            execution_result = (
                f"ERROR: Unknown action_type '{action.action_type}'. "
                "Use 'submit_fix' or 'request_hint'."
            )
            reward = -0.05

        # End episode if max steps reached
        if state.steps_taken >= state.max_steps:
            done = True

        state.cumulative_reward += reward

        return SQLObservation(
            broken_query=state.broken_query,
            schema_description=state.schema_description,
            task_description=state.task_description,
            execution_result=execution_result,
            is_correct=is_correct,
            hint=hint,
            steps_taken=state.steps_taken,
            max_steps=state.max_steps,
            hints_used=state.hints_used,
            done=done,
            reward=reward,
        )

    @property
    def state(self) -> SQLState:
        """The current episode state."""
        return self._state