Spaces:
Sleeping
Sleeping
"""
Pydantic models for the SQL Data Analyst OpenEnv environment.
Defines the Action, Observation, State, StepResult, and ResetResult typed models.
"""
| from __future__ import annotations | |
| from typing import Any, Dict, List, Literal, Optional | |
| from pydantic import BaseModel, Field | |
# ---------------------------------------------------------------------------
# Action
# ---------------------------------------------------------------------------
class SQLAction(BaseModel):
    """One move an agent makes in the SQL Data Analyst environment.

    ``action_type`` picks the move; the other two fields are its payload:

    - ``execute_query``: run a SQL query against the episode database.
    - ``describe_table``: get schema plus sample rows for a table.
    - ``submit_answer``: submit the final answer to be graded.
    - ``list_tables``: list every table available in this episode.
    - ``noop``: do nothing (the step is still consumed).

    The field descriptions state which payload each action needs, but no
    validator here enforces it -- both payload fields simply default to None.
    """

    # Closed set of verbs; any other value fails Literal validation.
    action_type: Literal[
        "execute_query", "describe_table", "submit_answer", "list_tables", "noop"
    ] = Field(..., description="Type of action to perform.")

    # NOTE(review): describe_table is documented as reading sql_query too,
    # presumably to name the target table -- confirm against the env handler.
    sql_query: Optional[str] = Field(
        None,
        description="SQL query string (required for execute_query and describe_table).",
    )

    # Free-form dict; the keys the grader expects depend on the active task.
    answer: Optional[Dict[str, Any]] = Field(
        None,
        description=(
            "Final answer dict submitted to the grader "
            "(required for submit_answer). Schema depends on the active task."
        ),
    )
# ---------------------------------------------------------------------------
# Observation
# ---------------------------------------------------------------------------
class SQLObservation(BaseModel):
    """What the agent sees after every step of an episode.

    Task context (id, goal, schema, sample rows) and the step budget are
    required; query results, errors, and hints are optional and default to None.
    """

    # --- task context ------------------------------------------------------
    task_id: str = Field(..., description="Identifier of the active task.")
    goal: str = Field(
        ..., description="Natural-language description of what the agent must accomplish."
    )
    schema_info: str = Field(
        ..., description="DDL / schema description of the available tables."
    )
    # The "up to 5" bound is documentation only; no length limit is enforced here.
    data_sample: List[Dict[str, Any]] = Field(
        ..., description="Up to 5 sample rows from the primary table, for orientation."
    )

    # --- feedback from the previous action ---------------------------------
    last_query_result: Optional[List[Dict[str, Any]]] = Field(
        None,
        description="Rows returned by the most recent execute_query action (None if no query yet).",
    )
    last_query_error: Optional[str] = Field(
        None, description="Error message if the last SQL query failed, otherwise None."
    )
    last_action_error: Optional[str] = Field(
        None,
        description="Error from the last action (malformed action, etc.), otherwise None.",
    )

    # --- step budget -------------------------------------------------------
    step_count: int = Field(..., description="Number of steps taken so far in this episode.")
    max_steps: int = Field(..., description="Maximum steps allowed before episode terminates.")

    # --- optional guidance -------------------------------------------------
    hints: Optional[List[str]] = Field(
        None, description="Optional hints unlocked as steps progress."
    )
# ---------------------------------------------------------------------------
# State
# ---------------------------------------------------------------------------
class SQLState(BaseModel):
    """Bookkeeping snapshot describing one episode as a whole.

    Every field is required; none has a default.
    """

    episode_id: str = Field(..., description="Unique identifier for this episode.")
    task_id: str = Field(..., description="Active task identifier.")
    step_count: int = Field(..., description="Number of steps taken so far.")
    # The [0.0, 1.0] range is stated in the description only; no ge/le bound
    # is declared, so out-of-range values are accepted.
    current_score: float = Field(..., description="Running score in [0.0, 1.0].")
    max_steps: int = Field(..., description="Maximum steps for this episode.")
    done: bool = Field(..., description="Whether the episode has ended.")
# ---------------------------------------------------------------------------
# StepResult (returned by the /step endpoint)
# ---------------------------------------------------------------------------
class StepResult(BaseModel):
    """Payload sent back by the /step endpoint.

    Bundles the post-step observation with that step's reward, the episode
    termination flag, and an optional diagnostics dict.
    """

    observation: SQLObservation
    reward: float = Field(..., description="Reward for the current step.")
    done: bool = Field(..., description="True if the episode has ended.")
    # default_factory gives every instance its own fresh empty dict.
    info: Dict[str, Any] = Field(
        description="Extra diagnostic info.",
        default_factory=dict,
    )
# ---------------------------------------------------------------------------
# ResetResult (returned by the /reset endpoint)
# ---------------------------------------------------------------------------
class ResetResult(BaseModel):
    """Result returned by the /reset endpoint.

    Carries the observation for the freshly reset episode, a ``done`` flag
    (defaults to False), and a diagnostics dict (defaults to empty). Unlike
    the /step result there is no ``reward`` field.
    """

    observation: SQLObservation = Field(
        ..., description="Observation for the newly reset episode."
    )
    # Same wording as the /step result's ``done``/``info`` so both endpoints'
    # generated schemas document these shared fields identically.
    done: bool = Field(default=False, description="True if the episode has ended.")
    # default_factory gives every instance its own fresh empty dict.
    info: Dict[str, Any] = Field(
        default_factory=dict, description="Extra diagnostic info."
    )