Spaces:
Sleeping
Sleeping
File size: 4,277 Bytes
"""SQLDrift ``EnvClient`` — tool-aware payload constructors + response parser.
Inherits :class:`openenv.core.env_client.EnvClient` so TRL rollouts,
notebook exploration, and integration tests all use the same WS-backed
session semantics. Stateful episodes MUST go through the ``/ws`` channel
(HTTP ``/step`` is stateless: one fresh env per request).
Convenience constructors (:meth:`SqlDriftEnv.action_list_tables`, etc.)
hide the discriminated-union boilerplate so agent code reads naturally::

    env = SqlDriftEnv(base_url="http://localhost:8000").sync()
    with env:
        r = env.reset(seed=42, scenario_id="03_cartesian_join")
        r = env.step(SqlDriftEnv.action_run_query("SELECT COUNT(*) FROM events"))
        ...
"""
from __future__ import annotations
from typing import Any
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
from models import (
ConsultDBAPayload,
DescribeTablePayload,
ExplainQueryPayload,
ListTablesPayload,
ReadChangelogPayload,
RunQueryPayload,
SampleRowsPayload,
SqlDriftAction,
SqlDriftObservation,
SqlDriftState,
SubmitRewritePayload,
ToolName,
)
class SqlDriftEnv(EnvClient[SqlDriftAction, SqlDriftObservation, SqlDriftState]):
    """SQLDrift environment client with one action builder per tool.

    Provides the three ``EnvClient`` hooks (action serialisation, result
    parsing, state parsing) plus static ``action_*`` factories that pair each
    ``ToolName`` member with its matching payload model, so callers never
    assemble the tool/payload combination by hand.
    """

    # ---- EnvClient hooks -------------------------------------------------

    def _step_payload(self, action: SqlDriftAction) -> dict[str, Any]:
        """Return *action* dumped via ``model_dump(mode="json")``."""
        wire_payload = action.model_dump(mode="json")
        return wire_payload

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[SqlDriftObservation]:
        """Convert a raw response payload into a typed ``StepResult``.

        Reads the ``"observation"``, ``"reward"`` and ``"done"`` keys of
        *payload*. A missing observation is validated as an empty dict, a
        missing reward stays ``None`` and a missing ``done`` becomes ``False``.
        """
        parsed_obs = SqlDriftObservation.model_validate(payload.get("observation", {}))

        step_reward = payload.get("reward")
        episode_over = bool(payload.get("done", False))

        # reward/done travel beside the observation rather than inside it (the
        # base transport strips them), so copy them back for agents that only
        # look at `.observation`.
        parsed_obs.reward = step_reward
        parsed_obs.done = episode_over

        return StepResult(observation=parsed_obs, reward=step_reward, done=episode_over)

    def _parse_state(self, payload: dict[str, Any]) -> SqlDriftState:
        """Validate *payload* as a ``SqlDriftState`` and return it."""
        state = SqlDriftState.model_validate(payload)
        return state

    # ---- Action factories ------------------------------------------------
    # Each factory takes only the arguments its tool needs and wraps them in
    # the matching payload model before building the action.

    @staticmethod
    def action_list_tables() -> SqlDriftAction:
        """Build a ``LIST_TABLES`` action (no arguments)."""
        body = ListTablesPayload()
        return SqlDriftAction(tool=ToolName.LIST_TABLES, payload=body)

    @staticmethod
    def action_describe_table(table: str) -> SqlDriftAction:
        """Build a ``DESCRIBE_TABLE`` action for *table*."""
        body = DescribeTablePayload(table=table)
        return SqlDriftAction(tool=ToolName.DESCRIBE_TABLE, payload=body)

    @staticmethod
    def action_sample_rows(table: str, limit: int = 5) -> SqlDriftAction:
        """Build a ``SAMPLE_ROWS`` action for *table* with the given *limit*."""
        body = SampleRowsPayload(table=table, limit=limit)
        return SqlDriftAction(tool=ToolName.SAMPLE_ROWS, payload=body)

    @staticmethod
    def action_run_query(sql: str) -> SqlDriftAction:
        """Build a ``RUN_QUERY`` action carrying *sql*."""
        body = RunQueryPayload(sql=sql)
        return SqlDriftAction(tool=ToolName.RUN_QUERY, payload=body)

    @staticmethod
    def action_explain_query(sql: str) -> SqlDriftAction:
        """Build an ``EXPLAIN_QUERY`` action carrying *sql*."""
        body = ExplainQueryPayload(sql=sql)
        return SqlDriftAction(tool=ToolName.EXPLAIN_QUERY, payload=body)

    @staticmethod
    def action_read_changelog() -> SqlDriftAction:
        """Build a ``READ_CHANGELOG`` action (no arguments)."""
        body = ReadChangelogPayload()
        return SqlDriftAction(tool=ToolName.READ_CHANGELOG, payload=body)

    @staticmethod
    def action_submit_rewrite(sql: str) -> SqlDriftAction:
        """Build a ``SUBMIT_REWRITE`` action carrying *sql*."""
        body = SubmitRewritePayload(sql=sql)
        return SqlDriftAction(tool=ToolName.SUBMIT_REWRITE, payload=body)

    @staticmethod
    def action_consult_dba(question: str) -> SqlDriftAction:
        """Build a ``CONSULT_DBA`` action carrying *question*."""
        body = ConsultDBAPayload(question=question)
        return SqlDriftAction(tool=ToolName.CONSULT_DBA, payload=body)
# Names exported by a star-import of this module.
__all__ = ["SqlDriftEnv"]
|