Spaces:
Sleeping
Sleeping
"""SQLDrift ``EnvClient`` — tool-aware payload constructors + response parser.

Inherits :class:`openenv.core.env_client.EnvClient` so TRL rollouts,
notebook exploration, and integration tests all use the same WS-backed
session semantics. Stateful episodes MUST go through the ``/ws`` channel
(HTTP ``/step`` is stateless: one fresh env per request).

Convenience constructors (:meth:`SqlDriftEnv.action_list_tables`, etc.)
hide the discriminated-union boilerplate so agent code reads naturally::

    env = SqlDriftEnv(base_url="http://localhost:8000").sync()
    with env:
        r = env.reset(seed=42, scenario_id="03_cartesian_join")
        r = env.step(SqlDriftEnv.action_run_query("SELECT COUNT(*) FROM events"))
        ...
"""
| from __future__ import annotations | |
| from typing import Any | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_client import EnvClient | |
| from models import ( | |
| ConsultDBAPayload, | |
| DescribeTablePayload, | |
| ExplainQueryPayload, | |
| ListTablesPayload, | |
| ReadChangelogPayload, | |
| RunQueryPayload, | |
| SampleRowsPayload, | |
| SqlDriftAction, | |
| SqlDriftObservation, | |
| SqlDriftState, | |
| SubmitRewritePayload, | |
| ToolName, | |
| ) | |
class SqlDriftEnv(EnvClient[SqlDriftAction, SqlDriftObservation, SqlDriftState]):
    """Tool-aware client for the SQLDrift OpenEnv environment.

    Implements the three ``EnvClient`` hooks (action serialisation, step
    result parsing, state parsing) and exposes one static factory per tool,
    so callers never build the ``SqlDriftAction`` tool/payload pair by hand.
    The factories are static and can be called on the class or on an
    instance alike.
    """

    # ------------------------------------------------------------------
    # EnvClient ABC implementations
    # ------------------------------------------------------------------

    def _step_payload(self, action: SqlDriftAction) -> dict[str, Any]:
        """Serialise *action* to a JSON-compatible dict for the transport."""
        return action.model_dump(mode="json")

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[SqlDriftObservation]:
        """Build a ``StepResult`` from a raw server payload.

        Args:
            payload: Decoded server message. The ``observation`` key holds the
                observation fields; ``reward`` and ``done`` sit beside it.

        Returns:
            A ``StepResult`` whose ``observation`` also carries the
            top-level ``reward`` and ``done`` values.
        """
        # ``or {}`` also covers an explicit ``"observation": null``, which a
        # plain ``.get(..., {})`` would hand straight to ``model_validate``.
        obs_data = payload.get("observation") or {}
        observation = SqlDriftObservation.model_validate(obs_data)
        # Base transport strips reward + done off the observation dict — we
        # re-populate them so the agent can read straight off `.observation`.
        reward = payload.get("reward")
        done = bool(payload.get("done", False))
        observation.reward = reward
        observation.done = done
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: dict[str, Any]) -> SqlDriftState:
        """Validate a raw state payload into a ``SqlDriftState``."""
        return SqlDriftState.model_validate(payload)

    # ------------------------------------------------------------------
    # Action factories — one per tool, accepting only the args that tool
    # cares about. Declared static so they work both as
    # ``SqlDriftEnv.action_x(...)`` and ``env.action_x(...)``; without the
    # decorator an instance call would bind ``self`` to the first argument.
    # NOTE(review): the payload discriminator (``kind``) is presumably
    # defaulted by each payload model — confirm in ``models``.
    # ------------------------------------------------------------------

    @staticmethod
    def action_list_tables() -> SqlDriftAction:
        """Return an action invoking the ``LIST_TABLES`` tool."""
        return SqlDriftAction(tool=ToolName.LIST_TABLES, payload=ListTablesPayload())

    @staticmethod
    def action_describe_table(table: str) -> SqlDriftAction:
        """Return an action invoking ``DESCRIBE_TABLE`` on *table*."""
        return SqlDriftAction(
            tool=ToolName.DESCRIBE_TABLE,
            payload=DescribeTablePayload(table=table),
        )

    @staticmethod
    def action_sample_rows(table: str, limit: int = 5) -> SqlDriftAction:
        """Return an action invoking ``SAMPLE_ROWS`` on *table*.

        Args:
            table: Table to sample from.
            limit: Number of rows requested (defaults to 5).
        """
        return SqlDriftAction(
            tool=ToolName.SAMPLE_ROWS,
            payload=SampleRowsPayload(table=table, limit=limit),
        )

    @staticmethod
    def action_run_query(sql: str) -> SqlDriftAction:
        """Return an action invoking ``RUN_QUERY`` with the SQL text *sql*."""
        return SqlDriftAction(
            tool=ToolName.RUN_QUERY,
            payload=RunQueryPayload(sql=sql),
        )

    @staticmethod
    def action_explain_query(sql: str) -> SqlDriftAction:
        """Return an action invoking ``EXPLAIN_QUERY`` with the SQL text *sql*."""
        return SqlDriftAction(
            tool=ToolName.EXPLAIN_QUERY,
            payload=ExplainQueryPayload(sql=sql),
        )

    @staticmethod
    def action_read_changelog() -> SqlDriftAction:
        """Return an action invoking the ``READ_CHANGELOG`` tool."""
        return SqlDriftAction(tool=ToolName.READ_CHANGELOG, payload=ReadChangelogPayload())

    @staticmethod
    def action_submit_rewrite(sql: str) -> SqlDriftAction:
        """Return an action invoking ``SUBMIT_REWRITE`` with the SQL text *sql*."""
        return SqlDriftAction(
            tool=ToolName.SUBMIT_REWRITE,
            payload=SubmitRewritePayload(sql=sql),
        )

    @staticmethod
    def action_consult_dba(question: str) -> SqlDriftAction:
        """Return an action invoking ``CONSULT_DBA`` with *question*."""
        return SqlDriftAction(
            tool=ToolName.CONSULT_DBA,
            payload=ConsultDBAPayload(question=question),
        )
# Public surface of this module: only the client class is re-exported by
# ``from <module> import *``; the payload models stay in ``models``.
__all__ = [
    "SqlDriftEnv",
]