sql_tutor_env / server /sql_environment.py
snigenigmatic's picture
Upload folder using huggingface_hub
2e6e0b2 verified
import random
import sqlite3
from collections import Counter
from typing import Any, Optional, Tuple

from openenv.core.env_server.interfaces import Environment

from models import SQLAction, SQLObservation, SQLState
from server.challenges import CHALLENGES
def _run_query(schema_sql: str, query: str) -> Tuple[bool, str]:
"""
Execute query against an in-memory SQLite DB seeded with schema_sql.
Returns (success: bool, result_string: str).
"""
try:
conn = sqlite3.connect(":memory:")
conn.executescript(schema_sql)
cursor = conn.execute(query)
rows = cursor.fetchall()
col_names = [desc[0] for desc in cursor.description] if cursor.description else []
conn.close()
if not rows:
return True, "(no rows returned)"
# Format as a simple text table
header = " | ".join(col_names)
sep = "-" * len(header)
row_lines = [" | ".join(str(v) for v in row) for row in rows]
return True, "\n".join([header, sep] + row_lines)
except Exception as e:
return False, f"ERROR: {e}"
def _results_match(schema_sql: str, query_a: str, query_b: str) -> bool:
"""Check whether two queries return identical result sets."""
try:
conn = sqlite3.connect(":memory:")
conn.executescript(schema_sql)
rows_a = set(conn.execute(query_a).fetchall())
rows_b = set(conn.execute(query_b).fetchall())
conn.close()
return rows_a == rows_b
except Exception:
return False
class SQLTutorEnvironment(Environment[SQLAction, SQLObservation, SQLState]):
    """Environment in which an agent repairs a broken SQL query.

    ``reset`` picks one entry from ``CHALLENGES`` and shows the agent the
    broken query together with its current output. Each ``step`` either
    hands out a hint (``action_type == "request_hint"``) or evaluates a
    submitted query (``action_type == "submit_fix"``) by comparing its result
    with the challenge's reference query. The episode ends when a submission
    matches or when ``max_steps`` steps have been taken.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    # Episode length.
    MAX_STEPS = 5

    # Reward for a correct fix, before discounts, and its lower bound.
    BASE_REWARD = 1.0
    MIN_SUCCESS_REWARD = 0.1
    # Discounts applied to a correct fix: per hint used / per extra step.
    HINT_DISCOUNT = 0.15
    STEP_DISCOUNT = 0.05

    # Per-step penalties.
    HINT_PENALTY = -0.1
    QUERY_ERROR_PENALTY = -0.1
    WRONG_RESULT_PENALTY = -0.05
    INVALID_ACTION_PENALTY = -0.05

    def __init__(self):
        super().__init__()
        self._state = SQLState()

    def _observe(
        self,
        execution_result: str,
        is_correct: bool,
        hint: Optional[str],
        done: bool,
        reward: Optional[float],
    ) -> SQLObservation:
        """Build an observation from the current state plus per-step results."""
        state = self._state
        return SQLObservation(
            broken_query=state.broken_query,
            schema_description=state.schema_description,
            task_description=state.task_description,
            execution_result=execution_result,
            is_correct=is_correct,
            hint=hint,
            steps_taken=state.steps_taken,
            max_steps=state.max_steps,
            hints_used=state.hints_used,
            done=done,
            reward=reward,
        )

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> SQLObservation:
        """Start a new episode on a randomly chosen challenge.

        Args:
            seed: When given, makes the challenge choice reproducible.
            episode_id: Identifier stored on the new state.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The initial observation, whose ``execution_result`` contains the
            output (or error text) of the broken query. ``reward`` is ``None``.
        """
        # A private generator gives the same choice that seeding the module
        # RNG would, without reseeding the process-wide generator that other
        # concurrent sessions draw from.
        rng = random.Random(seed) if seed is not None else random
        challenge = rng.choice(CHALLENGES)

        self._state = SQLState(
            challenge_id=challenge["id"],
            broken_query=challenge["broken_query"],
            correct_query=challenge["correct_query"],
            schema_sql=challenge["schema_sql"],
            schema_description=challenge["schema_description"],
            task_description=challenge["task_description"],
            # Copy so the episode state never aliases the shared challenge data.
            hints=list(challenge["hints"]),
            steps_taken=0,
            max_steps=self.MAX_STEPS,
            hints_used=0,
            is_resolved=False,
            cumulative_reward=0.0,
            episode_id=episode_id,
            step_count=0,
        )

        # Show the agent the broken query output so it understands what's wrong.
        _, broken_result = _run_query(self._state.schema_sql, self._state.broken_query)
        return self._observe(
            execution_result=f"Current (broken) query output:\n{broken_result}",
            is_correct=False,
            hint=None,
            done=False,
            reward=None,
        )

    def step(
        self,
        action: SQLAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> SQLObservation:
        """Apply one agent action and return the resulting observation.

        Args:
            action: The agent's action; ``action_type`` selects the branch and
                ``sql_query`` carries the submitted SQL for ``"submit_fix"``.
            timeout_s: Accepted for interface compatibility; unused here.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            An observation with the step's ``reward`` and ``done`` flag. Once
            the episode has ended, further calls change nothing and return a
            terminal observation with zero reward.
        """
        state = self._state

        # Do not let a finished episode keep accumulating steps or rewards.
        # NOTE(review): a step() before the first reset() depends on the
        # default values of SQLState, which are not visible here.
        if state.is_resolved or state.steps_taken >= state.max_steps:
            return self._observe(
                execution_result="ERROR: Episode has already ended. Call reset() to start a new episode.",
                is_correct=state.is_resolved,
                hint=None,
                done=True,
                reward=0.0,
            )

        state.steps_taken += 1
        state.step_count += 1

        reward = 0.0
        done = False
        hint = None
        is_correct = False

        if action.action_type == "request_hint":
            _, broken_output = _run_query(state.schema_sql, state.broken_query)
            execution_result = (
                f"(Hint requested - no query executed)\nBroken query output:\n{broken_output}"
            )
            if state.hints:
                # Once every hint has been shown, keep repeating the last one.
                hint_index = min(state.hints_used, len(state.hints) - 1)
                hint = state.hints[hint_index]
                state.hints_used += 1
                reward = self.HINT_PENALTY  # small penalty for using a hint
            else:
                # Nothing to hand out: avoid indexing an empty list and do not
                # charge the agent for a hint it never received.
                hint = "No hints are available for this challenge."

        elif action.action_type == "submit_fix":
            submitted = action.sql_query
            if not submitted or not submitted.strip():
                execution_result = "ERROR: You chose 'submit_fix' but provided no sql_query."
                reward = self.INVALID_ACTION_PENALTY
            else:
                success, execution_result = _run_query(state.schema_sql, submitted)
                if not success:
                    reward = self.QUERY_ERROR_PENALTY
                elif _results_match(state.schema_sql, submitted, state.correct_query):
                    is_correct = True
                    # Reward decreases with hints used and steps taken.
                    hint_penalty = self.HINT_DISCOUNT * state.hints_used
                    step_penalty = self.STEP_DISCOUNT * max(0, state.steps_taken - 1)
                    reward = max(
                        self.MIN_SUCCESS_REWARD,
                        self.BASE_REWARD - hint_penalty - step_penalty,
                    )
                    state.is_resolved = True
                    done = True
                else:
                    reward = self.WRONG_RESULT_PENALTY

        else:
            execution_result = (
                f"ERROR: Unknown action_type '{action.action_type}'. "
                "Use 'submit_fix' or 'request_hint'."
            )
            reward = self.INVALID_ACTION_PENALTY

        # End episode if max steps reached.
        if not done and state.steps_taken >= state.max_steps:
            done = True

        state.cumulative_reward += reward

        return self._observe(
            execution_result=execution_result,
            is_correct=is_correct,
            hint=hint,
            done=done,
            reward=reward,
        )

    @property
    def state(self) -> SQLState:
        """Return the current episode state."""
        return self._state