# sql_tutor_env / client.py
# snigenigmatic's picture
# Upload folder using huggingface_hub
# 0683cf4 verified
from typing import Dict
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from models import SQLAction, SQLObservation, SQLState
class SQLTutorEnv(EnvClient[SQLAction, SQLObservation, SQLState]):
    """
    Client for the SQL Tutor environment.

    Serializes ``SQLAction`` objects into JSON payloads for the server and
    converts the server's JSON responses back into ``SQLObservation`` /
    ``SQLState`` objects.

    Usage (sketch -- NOTE(review): the return types of ``reset``/``step`` and
    whether they are sync or async come from the base ``EnvClient``, which is
    not defined here; confirm against the installed openenv-core version):

        # Connect to a running HF Space
        env = SQLTutorEnv(base_url="https://your-space.hf.space")
        # Or load from the Hub
        env = SQLTutorEnv.from_hub("your-username/sql-tutor-env")

        result = env.reset()       # presumably a StepResult[SQLObservation]
        obs = result.observation
        result = env.step(SQLAction(action_type="submit_fix", sql_query="SELECT ..."))
    """

    # Fallback episode length used when the server omits ``max_steps``.
    # Shared by observation and state parsing so the two defaults cannot drift.
    _DEFAULT_MAX_STEPS = 5

    def __init__(self, base_url: str, **kwargs):
        """Create a client for the environment server at ``base_url``.

        Any extra keyword arguments are forwarded unchanged to ``EnvClient``.
        """
        super().__init__(base_url=base_url, **kwargs)

    def _step_payload(self, action: SQLAction) -> Dict:
        """Serialize ``action`` into the JSON body sent to the server.

        Only ``action_type`` and ``sql_query`` are transmitted.
        """
        return {
            "action_type": action.action_type,
            "sql_query": action.sql_query,
        }

    def _parse_result(self, payload: Dict) -> StepResult[SQLObservation]:
        """Build a ``StepResult`` from a server response payload.

        Observation fields missing from ``payload["observation"]`` fall back
        to neutral defaults (empty strings, ``False``, ``0``, ``None`` for
        ``hint``). An ``"observation"`` key that is present but ``null`` is
        treated as an empty observation instead of raising
        ``AttributeError``. ``reward`` defaults to ``0.0`` and ``done`` to
        ``False`` only when those keys are absent. Explicit values, including
        ``null``, are passed through unchanged.
        """
        # `or {}` rather than a .get() default, so an explicit JSON null
        # observation is handled as well as a missing key.
        obs_data = payload.get("observation") or {}
        observation = SQLObservation(
            broken_query=obs_data.get("broken_query", ""),
            schema_description=obs_data.get("schema_description", ""),
            task_description=obs_data.get("task_description", ""),
            execution_result=obs_data.get("execution_result", ""),
            is_correct=obs_data.get("is_correct", False),
            hint=obs_data.get("hint"),
            steps_taken=obs_data.get("steps_taken", 0),
            max_steps=obs_data.get("max_steps", self._DEFAULT_MAX_STEPS),
            hints_used=obs_data.get("hints_used", 0),
        )
        return StepResult(
            observation=observation,
            reward=payload.get("reward", 0.0),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict) -> SQLState:
        """Build an ``SQLState`` from a server state payload.

        Missing fields fall back to neutral defaults. ``hints`` is always a
        fresh list: a missing or ``null`` value becomes ``[]``.
        """
        return SQLState(
            challenge_id=payload.get("challenge_id", ""),
            broken_query=payload.get("broken_query", ""),
            correct_query=payload.get("correct_query", ""),
            schema_sql=payload.get("schema_sql", ""),
            schema_description=payload.get("schema_description", ""),
            task_description=payload.get("task_description", ""),
            # `or []` also covers an explicit JSON null, which .get(..., [])
            # would have passed through as None.
            hints=list(payload.get("hints") or []),
            steps_taken=payload.get("steps_taken", 0),
            max_steps=payload.get("max_steps", self._DEFAULT_MAX_STEPS),
            hints_used=payload.get("hints_used", 0),
            is_resolved=payload.get("is_resolved", False),
            cumulative_reward=payload.get("cumulative_reward", 0.0),
        )