File size: 4,470 Bytes
22328de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from typing import List, Dict, Optional, Literal, Union, Annotated, Any
from pydantic import BaseModel, Field, ConfigDict

class QueryAction(BaseModel):
    """Action that executes a single SQL query given in ``sql``."""

    model_config = ConfigDict(strict=True)

    # Discriminator value used by the ``Action`` union.
    action_type: Annotated[
        Literal["query"],
        Field(description="The type of action, must be 'query'"),
    ]
    sql: Annotated[str, Field(description="The SQL query to execute")]

class DDLAction(BaseModel):
    """Action that executes a single DDL statement given in ``sql``."""

    model_config = ConfigDict(strict=True)

    # Discriminator value used by the ``Action`` union.
    action_type: Annotated[
        Literal["ddl"],
        Field(description="The type of action, must be 'ddl'"),
    ]
    sql: Annotated[str, Field(description="The DDL SQL statement to execute")]

class TestAction(BaseModel):
    """Action that runs the environment's tests against ``target_table``."""

    # The ``Test`` prefix makes pytest try to collect this model as a test
    # class wherever a test module imports it, which emits a
    # PytestCollectionWarning (the class has an ``__init__``). Opt out
    # explicitly. Pydantic ignores dunder class attributes, so this is NOT
    # a model field and does not affect validation or the JSON schema.
    __test__ = False

    model_config = ConfigDict(strict=True)

    # Discriminator value used by the ``Action`` union.
    action_type: Literal["test"] = Field(..., description="The type of action, must be 'test'")
    target_table: str = Field(..., description="The target table to run tests against")

class SubmitAction(BaseModel):
    """Action tagged ``"submit"``; it carries no fields besides the tag."""

    model_config = ConfigDict(strict=True)

    # Discriminator value used by the ``Action`` union.
    action_type: Annotated[
        Literal["submit"],
        Field(description="The type of action, must be 'submit'"),
    ]

# Tagged union of every action model. Pydantic selects the concrete class
# from the value of each payload's ``action_type`` field.
Action = Annotated[
    Union[QueryAction, DDLAction, TestAction, SubmitAction],
    Field(
        discriminator='action_type',
        description="Union of all four actions, discriminated by action_type",
    ),
]

class Observation(BaseModel):
    """Everything the agent observes after a step.

    Query results and system logs are size-capped previews; the matching
    ``*_truncated`` flags report when more existed than is shown.
    """

    model_config = ConfigDict(strict=True)

    # Episode / task context.
    current_step: int = Field(..., description="The current step in the episode")
    max_steps: int = Field(..., description="The maximum number of steps allowed in the episode")
    task_id: int = Field(..., description="The unique identifier for the current task")
    task_description: str = Field(..., description="The description of the task")

    # Outcome of the most recent action.
    last_action_status: Literal["SUCCESS", "ERROR", "NONE"] = Field(..., description="The status of the last executed action")
    last_error_message: Optional[str] = Field(None, description="The error message from the last action, if any")

    # Query preview. The documented 10-row cap is enforced (mirroring the
    # ``system_logs`` cap below) so a producer that forgets to truncate fails
    # validation instead of silently shipping an oversized observation.
    query_results: List[Dict[str, Any]] = Field(default_factory=list, max_length=10, description="Up to 10 rows from the last query result")
    results_truncated: bool = Field(default=False, description="True if query returned more rows than shown")
    total_rows_returned: int = Field(default=0, description="Actual row count before truncation")

    schema_info: Dict[str, Any] = Field(..., description="Column names and types only — not data")
    system_logs: List[str] = Field(..., max_length=20, description="A list of system logs")
    logs_truncated: bool = Field(default=False, description="True if there were more logs than shown")
    progress_hint: Optional[str] = Field(None, description="A hint for the progress of the task, if available")

class Reward(BaseModel):
    """Reward for one step, with end-of-episode flags and grader scores."""

    model_config = ConfigDict(strict=True)

    # Per-step value is bounded to [-1.0, 1.0]; the running total is not.
    step_reward: Annotated[
        float,
        Field(ge=-1.0, le=1.0, description="The reward for the current step, between -1.0 and 1.0"),
    ]
    cumulative_reward: Annotated[float, Field(description="The total cumulative reward accumulated so far")]
    reward_breakdown: Annotated[
        Dict[str, float],
        Field(description="A breakdown of the components contributing to the reward"),
    ]

    # Two independent end-of-episode flags.
    done: Annotated[bool, Field(description="Whether the episode has finished successfully or not")]
    truncated: Annotated[
        bool,
        Field(description="Whether the episode was truncated (e.g., maximum steps reached)"),
    ]

    # Grader score on either side of the action.
    grader_score_before: Annotated[float, Field(description="Grader score before the action")]
    grader_score_after: Annotated[float, Field(description="Grader score after the action")]

class StateSnapshot(BaseModel):
    """Snapshot of an episode: table contents, trajectory, score and seed."""

    model_config = ConfigDict(strict=True)

    episode_id: Annotated[str, Field(description="The unique identifier of the episode")]
    task_id: Annotated[int, Field(description="The unique identifier of the task")]
    current_step: Annotated[int, Field(description="The current step count")]

    # Table name -> list of row dicts.
    tables: Annotated[
        Dict[str, List[Dict[str, Any]]],
        Field(description="The contents of the tables currently in the environment"),
    ]
    trajectory: Annotated[
        List[Dict[str, Any]],
        Field(description="The trajectory of actions and observations collected so far"),
    ]

    grader_score: Annotated[float, Field(description="The current score assigned by the grader")]
    seed: Annotated[int, Field(description="The random seed used for the current state snapshot")]
    # Only optional field; defaults to 1.0.
    difficulty_multiplier: Annotated[float, Field(description="Task difficulty curriculum multiplier")] = 1.0

class CompletedEpisode(BaseModel):
    """Summary record for one episode: identifiers, step count, final score, failures."""
    model_config = ConfigDict(strict=True)
    # All fields below are required (none has a default).
    # NOTE: field order determines serialization/schema order; keep it stable.
    episode_id: str
    task_id: int
    total_steps: int
    final_score: float
    # presumably one string per action that failed during the episode — TODO confirm
    # against the code that builds this record
    failed_actions: List[str]