md896's picture
Make OpenEnv training+API judge-proof
d061422
"""
Core SQL Debug Environment.
Manages episode state, delegates to tasks and reward function.
"""
import uuid
import asyncio
from typing import Optional, Dict, Any, List
from .models import (
SQLDebugAction, SQLDebugObservation, SQLDebugReward,
EpisodeState, ActionType, QueryResult, SchemaInfo
)
from .database import EpisodeDatabase
from .reward import compute_reward
from .tasks.task_easy import EasyTask
from .tasks.task_medium import MediumTask, MediumTaskGrader
from .tasks.task_hard import HardTask
from .tasks.task_finance_explosion import FinanceExplosionTask
# Registry of the available tasks, keyed by the task id that SQLDebugEnv
# accepts in its constructor. Each task object is instantiated once, when this
# module is imported, so every environment selecting the same id shares that
# single task instance.
TASKS = {
"easy_syntax_fix": EasyTask(),
"medium_logic_fix": MediumTask(),
"hard_multi_bug": HardTask(),
"hard_finance_explosion": FinanceExplosionTask(),
}
# Score an episode starts with: reset() uses it both as the initial
# best_score_so_far in the episode state and as the current_score of the
# first observation.
STRICT_MIN_SCORE = 0.001
class SQLDebugEnv:
    """
    The SQL Debug Environment.

    Manages one active episode at a time: an episode database, the episode
    state, and the task selected at construction time. ``reset()`` and
    ``step()`` run under an ``asyncio.Lock`` so concurrent coroutines using
    the same instance are serialized; callers are expected to create one
    instance per session.
    """

    def __init__(self, task_id: str = "easy_syntax_fix"):
        """
        Bind the environment to one task from the ``TASKS`` registry.

        Args:
            task_id: Key into ``TASKS``.

        Raises:
            KeyError: If ``task_id`` is not a registered task.
        """
        self.task_id = task_id
        self.task = TASKS[task_id]
        self._db: Optional[EpisodeDatabase] = None
        self._state: Optional[EpisodeState] = None
        self._lock = asyncio.Lock()

    @staticmethod
    def _last_query_error(
        action_history: List[Dict[str, Any]],
        submit_action_value: str,
    ) -> str:
        """
        Return the error text of the most recently submitted query.

        Every history entry carries an ``"error_message"`` key (``None`` for
        actions that ran no query or whose query succeeded), so a plain
        ``dict.get`` default on the last entry would never apply. The history
        is therefore scanned backwards for the latest submit entry.

        Args:
            action_history: Per-step records as appended by ``step()``.
            submit_action_value: The ``action_type`` string that marks a
                query submission in the history.

        Returns:
            The recorded error message, or a fixed explanatory string when
            the last submitted query had no error or nothing was submitted.
        """
        for entry in reversed(action_history):
            if entry.get("action_type") == submit_action_value:
                return entry.get("error_message") or "No error recorded from last query."
        return "No query has been submitted yet."

    async def reset(self) -> tuple[SQLDebugObservation, Dict]:
        """Reset environment to initial state. Returns (observation, info)."""
        async with self._lock:
            # Close the previous episode's database, if any, before replacing it.
            if self._db:
                self._db.close()

            # Fresh database built from the task's schema and seed data.
            self._db = EpisodeDatabase(
                task_id=self.task.task_id,
                schema_sql=self.task.schema_sql,
                seed_data_sql=self.task.seed_data_sql,
            )

            # Fresh episode state.
            self._state = EpisodeState(
                task_id=self.task.task_id,
                task_difficulty=self.task.difficulty,
                original_query=self.task.broken_query,
                current_query=None,
                best_score_so_far=STRICT_MIN_SCORE,
                steps_taken=0,
                max_steps=self.task.max_steps,
                action_history=[],
                reward_history=[],
                is_done=False,
                success=False,
                db_schema=self._db.get_schema(),
            )

            obs = SQLDebugObservation(
                task_id=self.task.task_id,
                task_description=self.task.description,
                original_query=self.task.broken_query,
                current_query=None,
                expected_description=self.task.expected_output_description,
                last_action_type="reset",
                last_query_result=None,
                steps_taken=0,
                steps_remaining=self.task.max_steps,
                current_score=STRICT_MIN_SCORE,
                schema_info=SchemaInfo(tables=self._db.get_schema()),
                is_done=False,
                success=False,
            )
            return obs, {"task": self.task.to_dict()}

    async def step(self, action: SQLDebugAction) -> tuple[SQLDebugObservation, float, bool, Dict]:
        """
        Execute one action.

        Returns (observation, reward_value, done, info), where ``info``
        carries ``grade_score``, ``reward_breakdown``, ``success`` and
        ``steps_taken``.

        Raises:
            RuntimeError: If no episode is active (``reset()`` was never
                called, or ``close()`` released the database) or the episode
                is already done.
            ValueError: If a ``submit_query`` action has no query, or an
                ``inspect_sample`` action has no table name. The action is
                rejected before any episode state changes, so it does not
                consume a step.
        """
        async with self._lock:
            if self._state is None or self._db is None:
                raise RuntimeError("Call reset() before step()")
            if self._state.is_done:
                raise RuntimeError("Episode is done. Call reset() to start new episode.")

            # Validate the payload before touching episode state so that a
            # malformed action cannot burn a step without being recorded.
            is_submit = action.action_type == ActionType.SUBMIT_QUERY
            if is_submit and not action.query:
                raise ValueError("query is required for submit_query action")
            if action.action_type == ActionType.INSPECT_SAMPLE and not action.table_name:
                raise ValueError("table_name required for inspect_sample")

            self._state.steps_taken += 1
            steps_taken = self._state.steps_taken

            prev_best_score = self._state.best_score_so_far
            # Non-submit actions do not re-grade; they carry the best score forward.
            grade_score = prev_best_score
            query_result_raw = None
            schema_info = None
            error_details = None
            sample_rows = None
            hint = None

            # --- Execute action ---
            if is_submit:
                self._state.current_query = action.query
                query_result_raw = self._db.execute_query(action.query)

                # Only a successful execution yields rows to grade.
                actual_rows = query_result_raw.get("rows") if query_result_raw.get("success") else None

                # The medium task is graded by its dedicated grader.
                if self.task.task_id == "medium_logic_fix":
                    grade_score = MediumTaskGrader.grade(actual_rows or [])
                else:
                    grade_score = self.task.grade(actual_rows)

                if grade_score > self._state.best_score_so_far:
                    self._state.best_score_so_far = grade_score

            elif action.action_type == ActionType.INSPECT_SCHEMA:
                schema_info = SchemaInfo(tables=self._db.get_schema())

            elif action.action_type == ActionType.INSPECT_ERROR:
                # Report the error of the latest submitted query, even when
                # other inspect actions were taken in between.
                error_details = self._last_query_error(
                    self._state.action_history,
                    ActionType.SUBMIT_QUERY.value,
                )

            elif action.action_type == ActionType.INSPECT_SAMPLE:
                sample_rows = self._db.get_sample_rows(action.table_name)

            elif action.action_type == ActionType.RESET_QUERY:
                self._state.current_query = self.task.broken_query

            # --- Compute reward ---
            schema_tables = list(self._db.get_schema().keys())
            reward_obj = compute_reward(
                action_type=action.action_type.value,
                query_result=query_result_raw,
                grade_score=grade_score,
                steps_taken=steps_taken,
                max_steps=self.task.max_steps,
                previous_best_score=prev_best_score,
                schema_tables=schema_tables,
                submitted_query=action.query if is_submit else None,
            )

            # --- Check done conditions ---
            is_done = False
            success = False
            if grade_score >= 0.95:
                is_done = True
                success = True
            elif steps_taken >= self.task.max_steps:
                # Out of steps: the episode still counts as a success when the
                # best score reached at least 0.5.
                is_done = True
                success = self._state.best_score_so_far >= 0.5
            self._state.is_done = is_done
            self._state.success = success

            # --- Hint logic ---
            # The task hint is revealed from step 3 on easy tasks, step 5 otherwise.
            hint_threshold = 3 if self.task.difficulty == "easy" else 5
            if steps_taken >= hint_threshold:
                hint = self.task.hint

            # --- Record history ---
            self._state.action_history.append({
                "step": steps_taken,
                "action_type": action.action_type.value,
                "query": action.query,
                "grade_score": grade_score,
                "reward": reward_obj.value,
                "error_message": query_result_raw.get("error_message") if query_result_raw else None,
            })
            self._state.reward_history.append(reward_obj.value)

            # --- Build observation ---
            # The episode state is fully updated at this point, so the shared
            # builder produces the post-action observation.
            obs = self.to_observation(
                last_action_type=action.action_type.value,
                last_query_result=QueryResult(**query_result_raw) if query_result_raw else None,
                schema_info=schema_info,
                error_details=error_details,
                sample_rows=sample_rows,
                hint=hint,
            )
            return obs, reward_obj.value, is_done, {
                "grade_score": grade_score,
                "reward_breakdown": reward_obj.breakdown,
                "success": success,
                "steps_taken": steps_taken,
            }

    def to_observation(
        self,
        *,
        last_action_type: str,
        last_query_result: Optional[QueryResult] = None,
        schema_info: Optional[SchemaInfo] = None,
        error_details: Optional[str] = None,
        sample_rows: Optional[List[Dict[str, Any]]] = None,
        hint: Optional[str] = None,
    ) -> SQLDebugObservation:
        """
        Build an observation from the current state without mutating the episode.

        Useful for endpoints that want to return an observation (e.g. reviewer
        rejection) without actually executing an action. Task fields come from
        the bound task, progress fields from the episode state, and the
        keyword arguments are passed through unchanged.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        if self._state is None:
            raise RuntimeError("Call reset() first")
        return SQLDebugObservation(
            task_id=self.task.task_id,
            task_description=self.task.description,
            original_query=self.task.broken_query,
            current_query=self._state.current_query,
            expected_description=self.task.expected_output_description,
            last_action_type=last_action_type,
            last_query_result=last_query_result,
            steps_taken=self._state.steps_taken,
            steps_remaining=max(0, self.task.max_steps - self._state.steps_taken),
            current_score=self._state.best_score_so_far,
            schema_info=schema_info,
            error_details=error_details,
            sample_rows=sample_rows,
            hint=hint,
            is_done=self._state.is_done,
            success=self._state.success,
        )

    def get_state(self) -> EpisodeState:
        """
        Return the live episode state object (not a copy).

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        if self._state is None:
            raise RuntimeError("Call reset() first")
        return self._state

    def close(self):
        """Close the episode database, if one is open, and drop the reference."""
        if self._db:
            self._db.close()
            self._db = None