Spaces:
Sleeping
Sleeping
File size: 2,528 Bytes
1b42f19 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 | """
SQL Migration Environment Client.
Provides the client for connecting to a SQL Migration Environment server.
Extends the base OpenEnv EnvClient for WebSocket-based persistent sessions.
Example:
>>> from sql_migration_env import DbMigrationEnv
>>>
>>> env = DbMigrationEnv(base_url="http://localhost:7860").sync()
>>> with env:
... result = env.reset()
... result = env.step({"sql_command": "ALTER TABLE users RENAME COLUMN first_name TO full_name", "reasoning": "test"})
... print(result.observation)
"""
from typing import Any, Dict
from openenv.core.env_client import EnvClient
from openenv.core.client_types import StepResult
from .models import MigrationAction, MigrationObservation, MigrationState
class DbMigrationEnv(EnvClient):
    """
    Client for the SQL Migration Environment.

    Inherits connection management, async/sync wrappers, and Docker/HF Space
    support from EnvClient. Actions may be passed as a ``MigrationAction`` or
    as a plain dict. Observations and state are returned as plain dicts taken
    from the server's JSON payload; they are not parsed into model objects.

    Example:
        >>> async with DbMigrationEnv(base_url="http://localhost:7860") as env:
        ...     result = await env.reset(task_name="column-restructure")
        ...     while not result.done:
        ...         action = {"sql_command": "...", "reasoning": "...", "submit_final": False}
        ...         result = await env.step(action)
        ...     print(f"Final score: {result.observation.get('migration_progress', 0)}")

    Example with sync wrapper:
        >>> env = DbMigrationEnv(base_url="http://localhost:7860").sync()
        >>> with env:
        ...     result = env.reset()
        ...     print(result.observation)
    """

    def _step_payload(self, action: Any) -> Dict[str, Any]:
        """Convert an action into the JSON payload sent to the server.

        Args:
            action: Either a ``MigrationAction`` (serialized with
                ``model_dump()``) or a dict already in wire format.

        Returns:
            The payload dict for the step request.

        Raises:
            ValueError: If ``action`` is neither a ``MigrationAction`` nor a dict.
        """
        if isinstance(action, MigrationAction):
            return action.model_dump()
        if isinstance(action, dict):
            # Dicts are forwarded unchanged; no schema validation happens here.
            return action
        # ValueError (rather than TypeError) is kept so existing callers that
        # catch it continue to work.
        raise ValueError(f"Expected MigrationAction or dict, got {type(action)}")

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
        """Parse a server step/reset response into a ``StepResult``.

        Args:
            payload: Decoded JSON response, read for the optional keys
                ``"observation"``, ``"reward"`` and ``"done"``.

        Returns:
            A ``StepResult`` with:

            - ``observation``: a dict, ``{}`` when the key is missing or null.
            - ``reward``: ``None`` when the key is missing.
            - ``done``: a bool, ``False`` when the key is missing or null.
        """
        observation = payload.get("observation")
        if observation is None:
            # ``.get(key, {})`` only covers a missing key. An explicit JSON null
            # would otherwise leak through as None and break callers that do
            # ``result.observation.get(...)``.
            observation = {}
        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            # Coerce so a null "done" reads as False rather than None.
            done=bool(payload.get("done", False)),
        )

    def _parse_state(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Return the server's state payload unchanged.

        No conversion into a model object is performed; callers receive the
        raw dict.
        """
        return payload
|