File size: 4,467 Bytes
7a0f237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
Pydantic models for the SQL Data Analyst OpenEnv environment.
Defines Action, Observation, State, and StepResult typed models.
"""

from __future__ import annotations

from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field


# ---------------------------------------------------------------------------
# Action
# ---------------------------------------------------------------------------

class SQLAction(BaseModel):
    """Action that an agent can take in the SQL Data Analyst environment."""

    # NOTE: the class docstring is kept verbatim; pydantic exposes it as the
    # model description in the generated JSON schema.
    #
    # Permitted action_type values (order preserved from the original):
    #   execute_query  - run a SQL query against the episode database
    #   describe_table - get schema + sample rows for a table
    #   submit_answer  - submit the final answer to be graded
    #   list_tables    - list all tables available in this episode
    #   noop           - do nothing (burn a step)
    action_type: Literal[
        "execute_query",
        "describe_table",
        "submit_answer",
        "list_tables",
        "noop",
    ] = Field(..., description="Type of action to perform.")

    # Both payload fields default to None. Nothing in this class enforces the
    # "required for ..." wording of the descriptions; any such check has to
    # happen wherever the action is consumed.
    sql_query: Optional[str] = Field(
        None,
        description="SQL query string (required for execute_query and describe_table).",
    )

    answer: Optional[Dict[str, Any]] = Field(
        None,
        description=(
            "Final answer dict submitted to the grader "
            "(required for submit_answer). "
            "Schema depends on the active task."
        ),
    )


# ---------------------------------------------------------------------------
# Observation
# ---------------------------------------------------------------------------

class SQLObservation(BaseModel):
    """Observation returned by the environment after each step."""

    # NOTE: the class docstring is kept verbatim; pydantic exposes it as the
    # model description in the generated JSON schema.

    # --- Task context (all required) ---------------------------------------
    task_id: str = Field(..., description="Identifier of the active task.")
    goal: str = Field(
        ...,
        description="Natural-language description of what the agent must accomplish.",
    )
    schema_info: str = Field(
        ...,
        description="DDL / schema description of the available tables.",
    )
    data_sample: List[Dict[str, Any]] = Field(
        ...,
        description="Up to 5 sample rows from the primary table, for orientation.",
    )

    # --- Outcome of the previous action (all optional, default None) -------
    last_query_result: Optional[List[Dict[str, Any]]] = Field(
        None,
        description="Rows returned by the most recent execute_query action (None if no query yet).",
    )
    last_query_error: Optional[str] = Field(
        None,
        description="Error message if the last SQL query failed, otherwise None.",
    )
    last_action_error: Optional[str] = Field(
        None,
        description="Error from the last action (malformed action, etc.), otherwise None.",
    )

    # --- Episode progress (required) ---------------------------------------
    step_count: int = Field(
        ...,
        description="Number of steps taken so far in this episode.",
    )
    max_steps: int = Field(
        ...,
        description="Maximum steps allowed before episode terminates.",
    )

    # --- Optional guidance --------------------------------------------------
    hints: Optional[List[str]] = Field(
        None,
        description="Optional hints unlocked as steps progress.",
    )


# ---------------------------------------------------------------------------
# State
# ---------------------------------------------------------------------------

class SQLState(BaseModel):
    """Episode-level state metadata."""

    # NOTE: the class docstring is kept verbatim; pydantic exposes it as the
    # model description in the generated JSON schema.
    # Every field is required; none declares a default.

    # Identity of the episode and the task it runs.
    episode_id: str = Field(..., description="Unique identifier for this episode.")
    task_id: str = Field(..., description="Active task identifier.")

    # Progress counters. The [0.0, 1.0] range of current_score is stated only
    # in its description; no numeric bound is declared on the field.
    step_count: int = Field(..., description="Number of steps taken so far.")
    current_score: float = Field(..., description="Running score in [0.0, 1.0].")
    max_steps: int = Field(..., description="Maximum steps for this episode.")

    # Termination flag.
    done: bool = Field(..., description="Whether the episode has ended.")


# ---------------------------------------------------------------------------
# StepResult  (returned by /step endpoint)
# ---------------------------------------------------------------------------

class StepResult(BaseModel):
    """Full result returned by the /step endpoint."""

    # NOTE: the class docstring is kept verbatim; pydantic exposes it as the
    # model description in the generated JSON schema.

    # Required: the observation produced by this step.
    observation: SQLObservation

    # Required scalar outcome of the step.
    reward: float = Field(..., description="Reward for the current step.")
    done: bool = Field(..., description="True if the episode has ended.")

    # Optional diagnostics; default_factory builds a new dict for each
    # instance when the field is omitted.
    info: Dict[str, Any] = Field(
        default_factory=dict,
        description="Extra diagnostic info.",
    )


# ---------------------------------------------------------------------------
# ResetResult  (returned by /reset endpoint)
# ---------------------------------------------------------------------------

class ResetResult(BaseModel):
    """Result returned by the /reset endpoint."""

    # NOTE: the class docstring is kept verbatim; pydantic exposes it as the
    # model description in the generated JSON schema.

    # Required: the initial observation of the new episode.
    observation: SQLObservation

    # Defaults to False when omitted. Unlike StepResult, this model carries
    # no reward field.
    done: bool = Field(default=False)

    # Optional diagnostics; default_factory builds a new dict for each
    # instance when the field is omitted.
    info: Dict[str, Any] = Field(default_factory=dict)