from typing import Dict

from openenv.core import EnvClient
from openenv.core.client_types import StepResult

from models import SQLAction, SQLObservation, SQLState


class SQLTutorEnv(EnvClient[SQLAction, SQLObservation, SQLState]):
    """
    Client for the SQL Tutor environment.

    Serializes ``SQLAction`` objects into JSON-style dicts for the server and
    parses the server's dict responses back into ``SQLObservation`` (wrapped
    in a ``StepResult``) and ``SQLState`` objects.

    Usage:
        # Connect to a running HF Space
        env = SQLTutorEnv(base_url="https://your-space.hf.space")

        # Or load from the Hugging Face Hub
        env = SQLTutorEnv.from_hub("your-username/sql-tutor-env")

        result = env.reset()
        result = env.step(SQLAction(action_type="submit_fix", sql_query="SELECT ..."))

    NOTE(review): an earlier version of this example unpacked
    ``obs, state = env.reset()``; confirm the return type of ``reset()``
    against ``openenv.core.EnvClient`` (this class only parses step results
    into ``StepResult`` and state payloads into ``SQLState``).
    """

    # Fallback used when the server omits ``max_steps``; shared by both
    # parsers so the observation and state defaults cannot drift apart.
    _DEFAULT_MAX_STEPS = 5

    def __init__(self, base_url: str, **kwargs):
        """Create a client for the server at ``base_url``.

        Extra keyword arguments are forwarded unchanged to ``EnvClient``.
        """
        super().__init__(base_url=base_url, **kwargs)

    def _step_payload(self, action: SQLAction) -> Dict:
        """Convert ``action`` into the dict sent to the server for a step."""
        return {
            "action_type": action.action_type,
            "sql_query": action.sql_query,
        }

    def _parse_result(self, payload: Dict) -> StepResult[SQLObservation]:
        """Build a ``StepResult`` from a server step/reset response.

        Missing fields fall back to neutral defaults. An ``observation`` key
        that is present but ``null`` is treated like a missing one (a plain
        ``dict.get`` default would return ``None`` there and crash on the
        next ``.get`` call).
        """
        obs_data = payload.get("observation") or {}
        observation = SQLObservation(
            broken_query=obs_data.get("broken_query", ""),
            schema_description=obs_data.get("schema_description", ""),
            task_description=obs_data.get("task_description", ""),
            execution_result=obs_data.get("execution_result", ""),
            is_correct=obs_data.get("is_correct", False),
            hint=obs_data.get("hint"),
            steps_taken=obs_data.get("steps_taken", 0),
            max_steps=obs_data.get("max_steps", self._DEFAULT_MAX_STEPS),
            hints_used=obs_data.get("hints_used", 0),
        )
        return StepResult(
            observation=observation,
            reward=payload.get("reward", 0.0),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict) -> SQLState:
        """Build an ``SQLState`` from a server state response.

        Missing fields fall back to neutral defaults. ``hints`` sent as
        ``null`` becomes an empty list, and the list is copied so the state
        does not alias the caller's payload.
        """
        return SQLState(
            challenge_id=payload.get("challenge_id", ""),
            broken_query=payload.get("broken_query", ""),
            correct_query=payload.get("correct_query", ""),
            schema_sql=payload.get("schema_sql", ""),
            schema_description=payload.get("schema_description", ""),
            task_description=payload.get("task_description", ""),
            hints=list(payload.get("hints") or []),
            steps_taken=payload.get("steps_taken", 0),
            max_steps=payload.get("max_steps", self._DEFAULT_MAX_STEPS),
            hints_used=payload.get("hints_used", 0),
            is_resolved=payload.get("is_resolved", False),
            cumulative_reward=payload.get("cumulative_reward", 0.0),
        )