File size: 2,528 Bytes
1b42f19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
SQL Migration Environment Client.

Provides the client for connecting to a SQL Migration Environment server.
Extends the base OpenEnv EnvClient for WebSocket-based persistent sessions.

Example:
    >>> from sql_migration_env import DbMigrationEnv
    >>>
    >>> env = DbMigrationEnv(base_url="http://localhost:7860").sync()
    >>> with env:
    ...     result = env.reset()
    ...     result = env.step({"sql_command": "ALTER TABLE users RENAME COLUMN first_name TO full_name", "reasoning": "test"})
    ...     print(result.observation)
"""

from typing import Any, Dict

from openenv.core.env_client import EnvClient
from openenv.core.client_types import StepResult

from .models import MigrationAction, MigrationObservation, MigrationState


class DbMigrationEnv(EnvClient):
    """
    Client for the SQL Migration Environment.

    Inherits connection management, async/sync wrappers, and Docker/HF Space
    support from EnvClient. Provides typed step/reset interactions.

    Observations and state are returned as plain dicts (the server payload is
    passed through, not converted into model instances), so read fields with
    ``result.observation.get(...)``.

    Example:
        >>> async with DbMigrationEnv(base_url="http://localhost:7860") as env:
        ...     result = await env.reset(task_name="column-restructure")
        ...     while not result.done:
        ...         action = {"sql_command": "...", "reasoning": "...", "submit_final": False}
        ...         result = await env.step(action)
        ...     print(f"Final score: {result.observation.get('migration_progress', 0)}")

    Example with sync wrapper:
        >>> env = DbMigrationEnv(base_url="http://localhost:7860").sync()
        >>> with env:
        ...     result = env.reset()
        ...     print(result.observation)
    """

    def _step_payload(self, action: Any) -> Dict[str, Any]:
        """Convert an action into the JSON payload sent to the server.

        Args:
            action: A ``MigrationAction`` model, or a plain dict carrying the
                same fields (e.g. ``sql_command``, ``reasoning``).

        Returns:
            A new dict. ``MigrationAction`` instances are serialized with
            ``model_dump()``; dict actions are shallow-copied so the payload
            does not alias the caller's object.

        Raises:
            ValueError: If ``action`` is neither a ``MigrationAction`` nor a dict.
        """
        if isinstance(action, MigrationAction):
            return action.model_dump()
        if isinstance(action, dict):
            # Copy so later changes to either the payload or the caller's dict
            # cannot leak into the other.
            return dict(action)
        # ValueError (rather than TypeError) is kept for backward compatibility
        # with callers that may already catch it.
        raise ValueError(f"Expected MigrationAction or dict, got {type(action)}")

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
        """Parse a server step/reset response into a ``StepResult``.

        Args:
            payload: Decoded JSON response; the keys read are ``observation``,
                ``reward`` and ``done``.

        Returns:
            ``StepResult`` whose ``observation`` is always a dict (empty when the
            payload omits it or sends null), ``reward`` is passed through as-is
            (may be None), and ``done`` is a bool (False when absent or null).
        """
        # ``.get(key, {})`` only covers a missing key; ``or {}`` also covers an
        # explicit null so ``result.observation.get(...)`` never hits None.
        observation = payload.get("observation") or {}
        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=bool(payload.get("done", False)),
        )

    def _parse_state(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Return the state response unchanged, as a plain dict.

        Args:
            payload: Decoded JSON state response from the server.

        Returns:
            The same ``payload`` object; no conversion to a state model is done.
        """
        return payload