sql-drift-env / client.py
visheshrathi's picture
Upload folder using huggingface_hub
5850885 verified
"""SQLDrift ``EnvClient`` β€” tool-aware payload constructors + response parser.
Inherits :class:`openenv.core.env_client.EnvClient` so TRL rollouts,
notebook exploration, and integration tests all use the same WS-backed
session semantics. Stateful episodes MUST go through the ``/ws`` channel
(HTTP ``/step`` is stateless: one fresh env per request).
Convenience constructors (:meth:`SqlDriftEnv.action_list_tables`, etc.)
hide the discriminated-union boilerplate so agent code reads naturally::
env = SqlDriftEnv(base_url="http://localhost:8000").sync()
with env:
r = env.reset(seed=42, scenario_id="03_cartesian_join")
r = env.step(SqlDriftEnv.action_run_query("SELECT COUNT(*) FROM events"))
...
"""
from __future__ import annotations
from typing import Any
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
from models import (
ConsultDBAPayload,
DescribeTablePayload,
ExplainQueryPayload,
ListTablesPayload,
ReadChangelogPayload,
RunQueryPayload,
SampleRowsPayload,
SqlDriftAction,
SqlDriftObservation,
SqlDriftState,
SubmitRewritePayload,
ToolName,
)
class SqlDriftEnv(EnvClient[SqlDriftAction, SqlDriftObservation, SqlDriftState]):
    """Tool-aware client for the SQLDrift OpenEnv environment.

    Implements the ``EnvClient`` hooks (action serialisation, step-result
    parsing, state parsing) and exposes one static factory per tool so
    callers never assemble ``SqlDriftAction`` payloads by hand.
    """

    # ------------------------------------------------------------------
    # EnvClient ABC implementations
    # ------------------------------------------------------------------

    def _step_payload(self, action: SqlDriftAction) -> dict[str, Any]:
        """Serialise *action* into a JSON-compatible dict for the transport."""
        return action.model_dump(mode="json")

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[SqlDriftObservation]:
        """Build a ``StepResult`` from a raw step/reset response.

        Args:
            payload: Response dict. The ``"observation"``, ``"reward"`` and
                ``"done"`` keys are read when present; ``done`` defaults to
                ``False`` and ``reward`` to ``None``.

        Returns:
            A ``StepResult`` whose ``observation`` also carries the same
            ``reward`` and ``done`` values as the result itself.
        """
        # ``or {}`` makes an explicit ``"observation": null`` validate the
        # same way as a missing key, instead of handing ``None`` to
        # ``model_validate``.
        obs_data = payload.get("observation") or {}
        observation = SqlDriftObservation.model_validate(obs_data)
        # Base transport strips reward + done off the observation dict — we
        # re-populate them so the agent can read straight off `.observation`.
        reward = payload.get("reward")
        done = bool(payload.get("done", False))
        observation.reward = reward
        observation.done = done
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: dict[str, Any]) -> SqlDriftState:
        """Validate a raw state dict into a ``SqlDriftState``."""
        return SqlDriftState.model_validate(payload)

    # ------------------------------------------------------------------
    # Action factories — one per tool, accepting only the args that tool
    # cares about; payload.kind is filled in automatically.
    # ------------------------------------------------------------------

    @staticmethod
    def action_list_tables() -> SqlDriftAction:
        """Build a ``LIST_TABLES`` action (the tool takes no arguments)."""
        return SqlDriftAction(tool=ToolName.LIST_TABLES, payload=ListTablesPayload())

    @staticmethod
    def action_describe_table(table: str) -> SqlDriftAction:
        """Build a ``DESCRIBE_TABLE`` action for *table*."""
        return SqlDriftAction(
            tool=ToolName.DESCRIBE_TABLE,
            payload=DescribeTablePayload(table=table),
        )

    @staticmethod
    def action_sample_rows(table: str, limit: int = 5) -> SqlDriftAction:
        """Build a ``SAMPLE_ROWS`` action for *table*.

        *limit* (default 5) is forwarded unchanged to ``SampleRowsPayload``.
        """
        return SqlDriftAction(
            tool=ToolName.SAMPLE_ROWS,
            payload=SampleRowsPayload(table=table, limit=limit),
        )

    @staticmethod
    def action_run_query(sql: str) -> SqlDriftAction:
        """Build a ``RUN_QUERY`` action carrying the SQL text *sql*."""
        return SqlDriftAction(
            tool=ToolName.RUN_QUERY,
            payload=RunQueryPayload(sql=sql),
        )

    @staticmethod
    def action_explain_query(sql: str) -> SqlDriftAction:
        """Build an ``EXPLAIN_QUERY`` action carrying the SQL text *sql*."""
        return SqlDriftAction(
            tool=ToolName.EXPLAIN_QUERY,
            payload=ExplainQueryPayload(sql=sql),
        )

    @staticmethod
    def action_read_changelog() -> SqlDriftAction:
        """Build a ``READ_CHANGELOG`` action (the tool takes no arguments)."""
        return SqlDriftAction(
            tool=ToolName.READ_CHANGELOG,
            payload=ReadChangelogPayload(),
        )

    @staticmethod
    def action_submit_rewrite(sql: str) -> SqlDriftAction:
        """Build a ``SUBMIT_REWRITE`` action carrying the rewritten SQL *sql*."""
        return SqlDriftAction(
            tool=ToolName.SUBMIT_REWRITE,
            payload=SubmitRewritePayload(sql=sql),
        )

    @staticmethod
    def action_consult_dba(question: str) -> SqlDriftAction:
        """Build a ``CONSULT_DBA`` action carrying the free-text *question*."""
        return SqlDriftAction(
            tool=ToolName.CONSULT_DBA,
            payload=ConsultDBAPayload(question=question),
        )
# Public surface of this module: only the client class is exported.
__all__ = [
    "SqlDriftEnv",
]