Spaces:
Running
Running
| """ | |
| Core SQL Debug Environment. | |
| Manages episode state, delegates to tasks and reward function. | |
| """ | |
| import uuid | |
| import asyncio | |
| from typing import Optional, Dict, Any, List | |
| from .models import ( | |
| SQLDebugAction, SQLDebugObservation, SQLDebugReward, | |
| EpisodeState, ActionType, QueryResult, SchemaInfo | |
| ) | |
| from .database import EpisodeDatabase | |
| from .reward import compute_reward | |
| from .tasks.task_easy import EasyTask | |
| from .tasks.task_medium import MediumTask, MediumTaskGrader | |
| from .tasks.task_hard import HardTask | |
| from .tasks.task_finance_explosion import FinanceExplosionTask | |
# Registry of the available debugging tasks, keyed by task id. One task
# instance per id is created at import time; SQLDebugEnv.__init__ looks its
# task up here.
TASKS = dict(
    easy_syntax_fix=EasyTask(),
    medium_logic_fix=MediumTask(),
    hard_multi_bug=HardTask(),
    hard_finance_explosion=FinanceExplosionTask(),
)

# Score every episode starts with: reset() uses it for both the initial
# best score in the episode state and the score reported in the first
# observation.
STRICT_MIN_SCORE = 0.001
class SQLDebugEnv:
    """
    The SQL Debug Environment.

    Manages one active episode at a time. Each session is expected to own its
    own instance (instance-per-session pattern); within one instance an
    ``asyncio.Lock`` serializes ``reset()`` and ``step()`` so that concurrent
    coroutines cannot interleave their updates of the episode state.
    """

    def __init__(self, task_id: str = "easy_syntax_fix"):
        """
        Bind the environment to one task from the module-level ``TASKS`` registry.

        No database or episode state is created here; call ``reset()`` first.

        Raises:
            KeyError: if ``task_id`` is not a key of ``TASKS``.
        """
        self.task_id = task_id
        self.task = TASKS[task_id]
        self._db: Optional[EpisodeDatabase] = None
        self._state: Optional[EpisodeState] = None
        self._lock = asyncio.Lock()

    async def reset(self) -> tuple[SQLDebugObservation, Dict]:
        """Reset environment to initial state. Returns (observation, info)."""
        async with self._lock:
            # Close previous DB if exists
            if self._db:
                self._db.close()

            # Fresh DB built from the task's schema and seed data
            self._db = EpisodeDatabase(
                task_id=self.task.task_id,
                schema_sql=self.task.schema_sql,
                seed_data_sql=self.task.seed_data_sql,
            )

            # Fresh state
            self._state = EpisodeState(
                task_id=self.task.task_id,
                task_difficulty=self.task.difficulty,
                original_query=self.task.broken_query,
                current_query=None,
                best_score_so_far=STRICT_MIN_SCORE,
                steps_taken=0,
                max_steps=self.task.max_steps,
                action_history=[],
                reward_history=[],
                is_done=False,
                success=False,
                db_schema=self._db.get_schema(),
            )

            obs = SQLDebugObservation(
                task_id=self.task.task_id,
                task_description=self.task.description,
                original_query=self.task.broken_query,
                current_query=None,
                expected_description=self.task.expected_output_description,
                last_action_type="reset",
                last_query_result=None,
                steps_taken=0,
                steps_remaining=self.task.max_steps,
                current_score=STRICT_MIN_SCORE,
                schema_info=SchemaInfo(tables=self._db.get_schema()),
                is_done=False,
                success=False,
            )
            return obs, {"task": self.task.to_dict()}

    @staticmethod
    def _last_query_error(history: List[Dict[str, Any]], submit_action_type: str) -> str:
        """
        Return the error text of the most recent submitted query in ``history``.

        ``history`` is the episode's action history (oldest first) and
        ``submit_action_type`` is the ``action_type`` value recorded for query
        submissions. Entries of other action types are skipped, because their
        ``error_message`` is always ``None`` and would otherwise hide the error
        of the query the agent is trying to inspect.
        """
        for entry in reversed(history):
            if entry.get("action_type") == submit_action_type:
                # The key is always present in history entries (None on
                # success), so a dict.get() default would never be used here.
                return entry.get("error_message") or "No error recorded from last query."
        return "No query has been submitted yet."

    async def step(self, action: SQLDebugAction) -> tuple[SQLDebugObservation, float, bool, Dict]:
        """
        Execute one action.

        Returns (observation, reward_value, done, info)

        Raises:
            RuntimeError: if there is no active episode (``reset()`` not called,
                or the environment was closed) or the episode is already done.
            ValueError: if a ``submit_query`` action has no query, or an
                ``inspect_sample`` action has no table name. The action is
                rejected before any state is touched, so it does not consume
                a step.
        """
        async with self._lock:
            if self._state is None or self._db is None:
                raise RuntimeError("Call reset() before step()")
            if self._state.is_done:
                raise RuntimeError("Episode is done. Call reset() to start new episode.")

            # Validate before mutating anything: a rejected action must not
            # bump the step counter without leaving a history entry behind.
            if action.action_type == ActionType.SUBMIT_QUERY and not action.query:
                raise ValueError("query is required for submit_query action")
            if action.action_type == ActionType.INSPECT_SAMPLE and not action.table_name:
                raise ValueError("table_name required for inspect_sample")

            self._state.steps_taken += 1
            steps_taken = self._state.steps_taken

            query_result_raw = None
            prev_best_score = self._state.best_score_so_far
            # Non-submitting actions are graded at the best score reached so far.
            grade_score = prev_best_score
            schema_info = None
            error_details = None
            sample_rows = None
            hint = None

            # --- Execute action ---
            if action.action_type == ActionType.SUBMIT_QUERY:
                self._state.current_query = action.query
                query_result_raw = self._db.execute_query(action.query)

                # Grade the result; a failed query is graded with no rows.
                actual_rows = query_result_raw.get("rows") if query_result_raw.get("success") else None

                # Use custom grader for medium task
                if self.task.task_id == "medium_logic_fix":
                    grade_score = MediumTaskGrader.grade(actual_rows or [])
                else:
                    grade_score = self.task.grade(actual_rows)

                if grade_score > self._state.best_score_so_far:
                    self._state.best_score_so_far = grade_score

            elif action.action_type == ActionType.INSPECT_SCHEMA:
                schema_info = SchemaInfo(tables=self._db.get_schema())

            elif action.action_type == ActionType.INSPECT_ERROR:
                # History only holds earlier steps at this point; the current
                # action is appended further down.
                error_details = self._last_query_error(
                    self._state.action_history,
                    ActionType.SUBMIT_QUERY.value,
                )

            elif action.action_type == ActionType.INSPECT_SAMPLE:
                sample_rows = self._db.get_sample_rows(action.table_name)

            elif action.action_type == ActionType.RESET_QUERY:
                self._state.current_query = self.task.broken_query

            # --- Compute reward ---
            schema_tables = list(self._db.get_schema().keys())
            reward_obj = compute_reward(
                action_type=action.action_type.value,
                query_result=query_result_raw,
                grade_score=grade_score,
                steps_taken=steps_taken,
                max_steps=self.task.max_steps,
                previous_best_score=prev_best_score,
                schema_tables=schema_tables,
                submitted_query=action.query if action.action_type == ActionType.SUBMIT_QUERY else None,
            )

            # --- Check done conditions ---
            is_done = False
            success = False
            if grade_score >= 0.95:
                is_done = True
                success = True
            elif steps_taken >= self.task.max_steps:
                # Out of steps: the episode still counts as a success when the
                # best score reached is at least 0.5.
                is_done = True
                success = self._state.best_score_so_far >= 0.5
            self._state.is_done = is_done
            self._state.success = success

            # --- Hint logic ---
            hint_threshold = 3 if self.task.difficulty == "easy" else 5
            if steps_taken >= hint_threshold:
                hint = self.task.hint

            # --- Record history ---
            self._state.action_history.append({
                "step": steps_taken,
                "action_type": action.action_type.value,
                "query": action.query,
                "grade_score": grade_score,
                "reward": reward_obj.value,
                "error_message": query_result_raw.get("error_message") if query_result_raw else None,
            })
            self._state.reward_history.append(reward_obj.value)

            # --- Build observation ---
            # The state is fully updated by now, so the shared builder yields
            # the same observation the episode state describes.
            obs = self.to_observation(
                last_action_type=action.action_type.value,
                last_query_result=QueryResult(**query_result_raw) if query_result_raw else None,
                schema_info=schema_info,
                error_details=error_details,
                sample_rows=sample_rows,
                hint=hint,
            )
            return obs, reward_obj.value, is_done, {
                "grade_score": grade_score,
                "reward_breakdown": reward_obj.breakdown,
                "success": success,
                "steps_taken": steps_taken,
            }

    def to_observation(
        self,
        *,
        last_action_type: str,
        last_query_result: Optional[QueryResult] = None,
        schema_info: Optional[SchemaInfo] = None,
        error_details: Optional[str] = None,
        sample_rows: Optional[List[Dict[str, Any]]] = None,
        hint: Optional[str] = None,
    ) -> SQLDebugObservation:
        """
        Build an observation from the current state without mutating the episode.

        Useful for endpoints that want to return an observation (e.g. reviewer
        rejection) without actually executing an action. Task-level fields come
        from the bound task, progress fields from the episode state, and the
        keyword arguments are passed through unchanged.

        Raises:
            RuntimeError: if ``reset()`` has not been called yet.
        """
        if self._state is None:
            raise RuntimeError("Call reset() first")
        return SQLDebugObservation(
            task_id=self.task.task_id,
            task_description=self.task.description,
            original_query=self.task.broken_query,
            current_query=self._state.current_query,
            expected_description=self.task.expected_output_description,
            last_action_type=last_action_type,
            last_query_result=last_query_result,
            steps_taken=self._state.steps_taken,
            steps_remaining=max(0, self.task.max_steps - self._state.steps_taken),
            current_score=self._state.best_score_so_far,
            schema_info=schema_info,
            error_details=error_details,
            sample_rows=sample_rows,
            hint=hint,
            is_done=self._state.is_done,
            success=self._state.success,
        )

    def get_state(self) -> EpisodeState:
        """
        Return the live episode state object (not a copy).

        Raises:
            RuntimeError: if ``reset()`` has not been called yet.
        """
        if self._state is None:
            raise RuntimeError("Call reset() first")
        return self._state

    def close(self):
        """Close the episode database, if one is open. Safe to call repeatedly."""
        if self._db:
            self._db.close()
            self._db = None