File size: 4,380 Bytes
1b42f19
 
 
 
 
 
 
 
 
 
 
 
05c4751
1b42f19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05c4751
 
 
 
 
1b42f19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05c4751
1b42f19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05c4751
 
 
 
41cae03
 
 
 
1b42f19
 
 
 
 
 
f294208
1b42f19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
Data models for SQL Migration Environment.

Defines the Action, Observation, and State types used by the environment,
client, and server. All models extend the OpenEnv base types.
"""

from __future__ import annotations

from typing import Any, Dict, Optional

from openenv.core.env_server.types import Action, Observation, State
from pydantic import Field, field_validator


class MigrationAction(Action):
    """
    A single agent move in the SQL Migration Environment.

    Each action carries one SQL statement to run against the episode's
    database, an optional free-text rationale, and a flag that asks the
    environment to finish the episode.

    Attributes:
        sql_command: The SQL statement to execute (for example ALTER TABLE,
                     UPDATE, or CREATE). Surrounding whitespace is trimmed
                     during validation.
        reasoning: The agent's explanation for choosing this statement.
                   Logged in metadata for Phase 3 human review; the grader
                   does not read it.
        submit_final: True when the agent considers the migration finished.
                      Triggers final grading and ends the episode.
    """

    sql_command: str = Field(
        description="The raw SQL statement to execute against the database",
    )
    reasoning: str = Field(
        description="Chain-of-thought explanation for this action",
        default="",
    )
    submit_final: bool = Field(
        description="Set to true when you believe the migration is complete",
        default=False,
    )

    @field_validator("sql_command")
    @classmethod
    def strip_whitespace(cls, value: str) -> str:
        """Return ``sql_command`` with leading/trailing whitespace removed."""
        trimmed = value.strip()
        return trimmed


class MigrationObservation(Observation):
    """
    Observation from the SQL Migration Environment.

    Returned after every reset() and step() call. Contains everything
    the agent needs to decide its next action.

    Inherits from Observation:
        done: bool — Whether the episode has terminated
        reward: float | None — Step reward (delta from previous score)
        metadata: dict — Additional metadata

    Attributes:
        current_schema_sql: Current database DDL from sqlite_master.
        target_schema_sql: Target database DDL the agent must achieve.
        last_execution_result: Result of the last SQL execution or error message.
        step_number: Current step count (0 after reset, increments each step).
        migration_progress: Current grader score from 0.0 to 1.0 (inclusive);
                            values outside this range fail validation.
        task_name: Name of the current task being attempted.
        schema_diff: Human-readable diff between current and target schemas,
                     or None when not provided.
        erd_visualization: Mermaid.js erDiagram text describing the current
                           database structure, or None when not provided.
    """

    current_schema_sql: str = Field(
        default="",
        description="Current database schema DDL from sqlite_master"
    )
    target_schema_sql: str = Field(
        default="",
        description="Target database schema DDL the agent must achieve"
    )
    last_execution_result: str = Field(
        default="",
        description="Result of the last SQL execution or error string"
    )
    step_number: int = Field(
        default=0,
        description="Current step count in the episode"
    )
    # Bounded to [0.0, 1.0]: pydantic rejects out-of-range scores.
    migration_progress: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Current migration progress score from 0.0 to 1.0"
    )
    task_name: str = Field(
        default="",
        description="Name of the current task"
    )
    schema_diff: Optional[str] = Field(
        default=None,
        description="Human-readable diff between current and expected target schemas"
    )
    erd_visualization: Optional[str] = Field(
        default=None,
        description="Mermaid.js erDiagram representation of the current database structure"
    )


class MigrationState(State):
    """
    State for the SQL Migration Environment.

    Returned by the state() method. Contains episode metadata.

    Inherits from State:
        episode_id: str — Unique episode identifier
        step_count: int — Number of steps taken

    Attributes:
        task_name: Name of the current task (defaults to "column-restructure").
        migration_progress: Current grader score, constrained to the range
                            0.0–1.0 (inclusive); out-of-range values fail
                            validation.
        max_steps: Maximum steps allowed per episode (defaults to 15).
    """

    task_name: str = Field(
        default="column-restructure",
        description="Name of the current task"
    )
    # Mirrors the ge/le bounds on MigrationObservation.migration_progress:
    # a score outside [0.0, 1.0] could never be reported in an observation,
    # so reject it here as well instead of silently storing it.
    migration_progress: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Current migration progress score"
    )
    max_steps: int = Field(
        default=15,
        description="Maximum steps allowed per episode"
    )