Spaces:
Sleeping
Sleeping
| from typing import List, Dict, Optional, Literal, Union, Annotated, Any | |
| from pydantic import BaseModel, Field, ConfigDict | |
class QueryAction(BaseModel):
    """Action tagged ``"query"`` that carries a SQL query string.

    Attributes:
        action_type: Literal tag, always ``"query"``.
        sql: The SQL query to execute.
    """

    # Strict mode: values must already have the declared type (no lax coercion).
    model_config = ConfigDict(strict=True)

    action_type: Literal["query"] = Field(
        description="The type of action, must be 'query'",
    )
    sql: str = Field(description="The SQL query to execute")
class DDLAction(BaseModel):
    """Action tagged ``"ddl"`` that carries a DDL SQL statement.

    Attributes:
        action_type: Literal tag, always ``"ddl"``.
        sql: The DDL SQL statement to execute.
    """

    # Strict mode: values must already have the declared type (no lax coercion).
    model_config = ConfigDict(strict=True)

    action_type: Literal["ddl"] = Field(
        description="The type of action, must be 'ddl'",
    )
    sql: str = Field(description="The DDL SQL statement to execute")
class TestAction(BaseModel):
    """Action tagged ``"test"`` that names the table to run tests against.

    Attributes:
        action_type: Literal tag, always ``"test"``.
        target_table: The target table to run tests against.
    """

    # Strict mode: values must already have the declared type (no lax coercion).
    model_config = ConfigDict(strict=True)

    action_type: Literal["test"] = Field(
        description="The type of action, must be 'test'",
    )
    target_table: str = Field(
        description="The target table to run tests against",
    )
class SubmitAction(BaseModel):
    """Action tagged ``"submit"``; it carries no payload beyond its tag.

    Attributes:
        action_type: Literal tag, always ``"submit"``.
    """

    # Strict mode: values must already have the declared type (no lax coercion).
    model_config = ConfigDict(strict=True)

    action_type: Literal["submit"] = Field(
        description="The type of action, must be 'submit'",
    )
# Tagged union over the four action models; pydantic picks the concrete
# model by looking at the value of the ``action_type`` field.
Action = Annotated[
    Union[QueryAction, DDLAction, TestAction, SubmitAction],
    Field(
        discriminator="action_type",
        description="Union of all four actions, discriminated by action_type",
    ),
]
class Observation(BaseModel):
    """Observation record for one step of an episode.

    Groups the step counters, the task identity, the outcome of the last
    action, a preview of the last query result, schema metadata, system
    logs and an optional progress hint. See each field's ``description``
    for its meaning.
    """

    # Strict mode: values must already have the declared type (no lax coercion).
    model_config = ConfigDict(strict=True)

    # --- Episode / task position -------------------------------------------
    current_step: int = Field(description="The current step in the episode")
    max_steps: int = Field(
        description="The maximum number of steps allowed in the episode",
    )
    task_id: int = Field(
        description="The unique identifier for the current task",
    )
    task_description: str = Field(description="The description of the task")

    # --- Outcome of the last action ----------------------------------------
    last_action_status: Literal["SUCCESS", "ERROR", "NONE"] = Field(
        description="The status of the last executed action",
    )
    last_error_message: Optional[str] = Field(
        default=None,
        description="The error message from the last action, if any",
    )

    # --- Query result preview ----------------------------------------------
    # NOTE(review): the "up to 10 rows" limit is stated only in the
    # description; no length constraint is enforced on this field.
    query_results: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Up to 10 rows from the last query result",
    )
    results_truncated: bool = Field(
        default=False,
        description="True if query returned more rows than shown",
    )
    total_rows_returned: int = Field(
        default=0,
        description="Actual row count before truncation",
    )

    # --- Schema and logs ---------------------------------------------------
    schema_info: Dict[str, Any] = Field(
        description="Column names and types only — not data",
    )
    # Validation rejects a list longer than 20 entries.
    system_logs: List[str] = Field(
        max_length=20,
        description="A list of system logs",
    )
    logs_truncated: bool = Field(
        default=False,
        description="True if there were more logs than shown",
    )

    # --- Optional guidance -------------------------------------------------
    progress_hint: Optional[str] = Field(
        default=None,
        description="A hint for the progress of the task, if available",
    )
class Reward(BaseModel):
    """Reward record for a single step.

    Holds the bounded per-step reward, the running total, a named
    breakdown of reward components, the two episode-termination flags and
    the grader score before and after the action.
    """

    # Strict mode: values must already have the declared type (no lax coercion).
    model_config = ConfigDict(strict=True)

    # Validation enforces -1.0 <= step_reward <= 1.0.
    step_reward: float = Field(
        ge=-1.0,
        le=1.0,
        description="The reward for the current step, between -1.0 and 1.0",
    )
    cumulative_reward: float = Field(
        description="The total cumulative reward accumulated so far",
    )
    reward_breakdown: Dict[str, float] = Field(
        description="A breakdown of the components contributing to the reward",
    )

    # Termination flags.
    done: bool = Field(
        description="Whether the episode has finished successfully or not",
    )
    truncated: bool = Field(
        description="Whether the episode was truncated (e.g., maximum steps reached)",
    )

    # Grader score on either side of the action.
    grader_score_before: float = Field(
        description="Grader score before the action",
    )
    grader_score_after: float = Field(
        description="Grader score after the action",
    )
class StateSnapshot(BaseModel):
    """Snapshot of an episode's state.

    Records the episode and task identifiers, the step count, the current
    table contents, the trajectory collected so far, the grader score, the
    random seed and the difficulty multiplier.
    """

    # Strict mode: values must already have the declared type (no lax coercion).
    model_config = ConfigDict(strict=True)

    # --- Identity and position ---------------------------------------------
    episode_id: str = Field(
        description="The unique identifier of the episode",
    )
    task_id: int = Field(description="The unique identifier of the task")
    current_step: int = Field(description="The current step count")

    # --- Environment contents and history ----------------------------------
    # Maps a string key to a list of row dicts (presumably keyed by table
    # name — confirm against the producer).
    tables: Dict[str, List[Dict[str, Any]]] = Field(
        description="The contents of the tables currently in the environment",
    )
    trajectory: List[Dict[str, Any]] = Field(
        description="The trajectory of actions and observations collected so far",
    )

    # --- Scoring and reproducibility ---------------------------------------
    grader_score: float = Field(
        description="The current score assigned by the grader",
    )
    seed: int = Field(
        description="The random seed used for the current state snapshot",
    )
    difficulty_multiplier: float = Field(
        default=1.0,
        description="Task difficulty curriculum multiplier",
    )
class CompletedEpisode(BaseModel):
    """Record of a completed episode.

    All fields are required; none declares a default or a description.
    """

    # Strict mode: values must already have the declared type (no lax coercion).
    model_config = ConfigDict(strict=True)

    # Identifiers.
    episode_id: str
    task_id: int

    # Totals for the episode.
    total_steps: int
    final_score: float

    # NOTE(review): the format of each entry (raw SQL, action type, ...) is
    # not established here — confirm against the code that builds this list.
    failed_actions: List[str]