"""
Core SQL Debug Environment.
Manages episode state, delegates to tasks and reward function.
"""
import uuid
import asyncio
from typing import Optional, Dict, Any, List

from .models import (
    SQLDebugAction, SQLDebugObservation, SQLDebugReward,
    EpisodeState, ActionType, QueryResult, SchemaInfo
)
from .database import EpisodeDatabase
from .reward import compute_reward
from .tasks.task_easy import EasyTask
from .tasks.task_medium import MediumTask, MediumTaskGrader
from .tasks.task_hard import HardTask
from .tasks.task_finance_explosion import FinanceExplosionTask

# Registry of available tasks, keyed by the task id accepted by SQLDebugEnv.
TASKS = {
    "easy_syntax_fix": EasyTask(),
    "medium_logic_fix": MediumTask(),
    "hard_multi_bug": HardTask(),
    "hard_finance_explosion": FinanceExplosionTask(),
}

# Score every episode starts from (used for both best score and first observation).
STRICT_MIN_SCORE = 0.001

# A submitted query graded at or above this score ends the episode successfully.
_SOLVED_SCORE = 0.95
# When the step budget runs out, the episode counts as a success if the best
# score reached at least this value.
_OUT_OF_STEPS_SUCCESS_SCORE = 0.5
# Step count from which the task hint is included in observations.
_HINT_AFTER_STEPS_EASY = 3
_HINT_AFTER_STEPS_DEFAULT = 5


class SQLDebugEnv:
    """
    The SQL Debug Environment.

    Manages one active episode at a time. Each session is expected to own its
    own instance; ``reset()`` and ``step()`` on a single instance are
    serialised with an ``asyncio.Lock``.
    """

    def __init__(self, task_id: str = "easy_syntax_fix"):
        """
        Bind the environment to one task from ``TASKS``.

        Raises:
            KeyError: if ``task_id`` is not a key of ``TASKS``.
        """
        self.task_id = task_id
        self.task = TASKS[task_id]
        self._db: Optional[EpisodeDatabase] = None
        self._state: Optional[EpisodeState] = None
        self._lock = asyncio.Lock()

    async def reset(self) -> tuple[SQLDebugObservation, Dict]:
        """Reset environment to initial state. Returns (observation, info)."""
        async with self._lock:
            # Close previous DB if exists. Drop the references immediately so
            # that a failure while building the new database cannot leave a
            # closed handle (or the stale episode that used it) behind.
            if self._db:
                self._db.close()
                self._db = None
                self._state = None

            # Fresh DB
            self._db = EpisodeDatabase(
                task_id=self.task.task_id,
                schema_sql=self.task.schema_sql,
                seed_data_sql=self.task.seed_data_sql
            )

            # Fresh state
            self._state = EpisodeState(
                task_id=self.task.task_id,
                task_difficulty=self.task.difficulty,
                original_query=self.task.broken_query,
                current_query=None,
                best_score_so_far=STRICT_MIN_SCORE,
                steps_taken=0,
                max_steps=self.task.max_steps,
                action_history=[],
                reward_history=[],
                is_done=False,
                success=False,
                db_schema=self._db.get_schema()
            )

            obs = SQLDebugObservation(
                task_id=self.task.task_id,
                task_description=self.task.description,
                original_query=self.task.broken_query,
                current_query=None,
                expected_description=self.task.expected_output_description,
                last_action_type="reset",
                last_query_result=None,
                steps_taken=0,
                steps_remaining=self.task.max_steps,
                current_score=STRICT_MIN_SCORE,
                schema_info=SchemaInfo(tables=self._db.get_schema()),
                is_done=False,
                success=False
            )
            return obs, {"task": self.task.to_dict()}

    def _last_query_error(self) -> str:
        """
        Return the error message of the most recently submitted query.

        Walks the action history backwards so that inspect/reset actions
        recorded after a failing query do not hide its error. Falls back to a
        fixed message when the last query recorded no error, or when no query
        has been submitted in this episode.
        """
        for entry in reversed(self._state.action_history):
            if entry["action_type"] == ActionType.SUBMIT_QUERY.value:
                # The key is always present in history entries (None when the
                # query had no error), so use `or` rather than a .get() default.
                return entry.get("error_message") or "No error recorded from last query."
        return "No query has been submitted yet."

    async def step(self, action: SQLDebugAction) -> tuple[SQLDebugObservation, float, bool, Dict]:
        """
        Execute one action.
        Returns (observation, reward_value, done, info)

        Episode state is only updated after the action has been executed and
        its reward computed, so an action that raises does not consume a step.

        Raises:
            RuntimeError: if ``reset()`` has not been called, or the episode
                is already done.
            ValueError: if a submit_query action has no query, or an
                inspect_sample action has no table_name.
        """
        async with self._lock:
            if self._state is None:
                raise RuntimeError("Call reset() before step()")
            if self._state.is_done:
                raise RuntimeError("Episode is done. Call reset() to start new episode.")

            state = self._state
            action_type = action.action_type

            # --- Validate before touching any state ---
            if action_type == ActionType.SUBMIT_QUERY and not action.query:
                raise ValueError("query is required for submit_query action")
            if action_type == ActionType.INSPECT_SAMPLE and not action.table_name:
                raise ValueError("table_name required for inspect_sample")

            # Work on local copies; they are committed to the state below.
            steps_taken = state.steps_taken + 1
            prev_best_score = state.best_score_so_far
            best_score = prev_best_score
            # Non-submit actions are "graded" at the best score reached so far.
            grade_score = prev_best_score
            current_query = state.current_query
            query_result_raw = None
            schema_info = None
            error_details = None
            sample_rows = None
            hint = None

            # --- Execute action ---
            if action_type == ActionType.SUBMIT_QUERY:
                current_query = action.query
                query_result_raw = self._db.execute_query(action.query)

                # Grade the result; a failed query is graded with no rows.
                actual_rows = query_result_raw.get("rows") if query_result_raw.get("success") else None

                # Use custom grader for medium task
                if self.task.task_id == "medium_logic_fix":
                    grade_score = MediumTaskGrader.grade(actual_rows or [])
                else:
                    grade_score = self.task.grade(actual_rows)

                if grade_score > best_score:
                    best_score = grade_score

            elif action_type == ActionType.INSPECT_SCHEMA:
                schema_info = SchemaInfo(tables=self._db.get_schema())

            elif action_type == ActionType.INSPECT_ERROR:
                error_details = self._last_query_error()

            elif action_type == ActionType.INSPECT_SAMPLE:
                sample_rows = self._db.get_sample_rows(action.table_name)

            elif action_type == ActionType.RESET_QUERY:
                current_query = self.task.broken_query

            # --- Compute reward ---
            schema_tables = list(self._db.get_schema().keys())
            reward_obj = compute_reward(
                action_type=action_type.value,
                query_result=query_result_raw,
                grade_score=grade_score,
                steps_taken=steps_taken,
                max_steps=self.task.max_steps,
                previous_best_score=prev_best_score,
                schema_tables=schema_tables,
                submitted_query=action.query if action_type == ActionType.SUBMIT_QUERY else None
            )

            # --- Check done conditions ---
            is_done = False
            success = False
            if grade_score >= _SOLVED_SCORE:
                is_done = True
                success = True
            elif steps_taken >= self.task.max_steps:
                is_done = True
                success = best_score >= _OUT_OF_STEPS_SUCCESS_SCORE

            # --- Hint logic ---
            if self.task.difficulty == "easy":
                hint_threshold = _HINT_AFTER_STEPS_EASY
            else:
                hint_threshold = _HINT_AFTER_STEPS_DEFAULT
            if steps_taken >= hint_threshold:
                hint = self.task.hint

            # Build the result model before committing, so a malformed raw
            # result cannot leave the episode half-updated.
            qr = QueryResult(**query_result_raw) if query_result_raw else None

            # --- Commit state and record history ---
            state.steps_taken = steps_taken
            state.current_query = current_query
            state.best_score_so_far = best_score
            state.is_done = is_done
            state.success = success
            state.action_history.append({
                "step": steps_taken,
                "action_type": action_type.value,
                "query": action.query,
                "grade_score": grade_score,
                "reward": reward_obj.value,
                "error_message": query_result_raw.get("error_message") if query_result_raw else None
            })
            state.reward_history.append(reward_obj.value)

            # --- Build observation (from the state committed above) ---
            obs = self.to_observation(
                last_action_type=action_type.value,
                last_query_result=qr,
                schema_info=schema_info,
                error_details=error_details,
                sample_rows=sample_rows,
                hint=hint,
            )

            return obs, reward_obj.value, is_done, {
                "grade_score": grade_score,
                "reward_breakdown": reward_obj.breakdown,
                "success": success,
                "steps_taken": steps_taken
            }

    def to_observation(
        self,
        *,
        last_action_type: str,
        last_query_result: Optional[QueryResult] = None,
        schema_info: Optional[SchemaInfo] = None,
        error_details: Optional[str] = None,
        sample_rows: Optional[List[Dict[str, Any]]] = None,
        hint: Optional[str] = None,
    ) -> SQLDebugObservation:
        """
        Build an observation from the current state without mutating the episode.

        Useful for endpoints that want to return an observation (e.g. reviewer
        rejection) without actually executing an action.

        Raises:
            RuntimeError: if ``reset()`` has not been called yet.
        """
        if self._state is None:
            raise RuntimeError("Call reset() first")
        return SQLDebugObservation(
            task_id=self.task.task_id,
            task_description=self.task.description,
            original_query=self.task.broken_query,
            current_query=self._state.current_query,
            expected_description=self.task.expected_output_description,
            last_action_type=last_action_type,
            last_query_result=last_query_result,
            steps_taken=self._state.steps_taken,
            steps_remaining=max(0, self.task.max_steps - self._state.steps_taken),
            current_score=self._state.best_score_so_far,
            schema_info=schema_info,
            error_details=error_details,
            sample_rows=sample_rows,
            hint=hint,
            is_done=self._state.is_done,
            success=self._state.success,
        )

    def get_state(self) -> EpisodeState:
        """
        Return the live episode state object (not a copy).

        Raises:
            RuntimeError: if ``reset()`` has not been called yet.
        """
        if self._state is None:
            raise RuntimeError("Call reset() first")
        return self._state

    def close(self):
        """Close the episode database, if one is open, and drop the reference."""
        if self._db:
            self._db.close()
            self._db = None