Spaces:
Sleeping
Sleeping
File size: 4,470 Bytes
22328de | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 | from typing import List, Dict, Optional, Literal, Union, Annotated, Any
from pydantic import BaseModel, Field, ConfigDict
class QueryAction(BaseModel):
    # Action carrying a SQL query for the environment to execute.
    # Strict mode: pydantic rejects inputs of the wrong type instead of coercing them.
    model_config = ConfigDict(strict=True)

    # Tag read by the ``Action`` union's discriminator; only "query" is accepted.
    action_type: Literal["query"] = Field(
        ...,
        description="The type of action, must be 'query'",
    )
    # SQL text to execute (required).
    sql: str = Field(
        ...,
        description="The SQL query to execute",
    )
class DDLAction(BaseModel):
    # Action carrying a DDL SQL statement for the environment to execute.
    # Strict mode: pydantic rejects inputs of the wrong type instead of coercing them.
    model_config = ConfigDict(strict=True)

    # Tag read by the ``Action`` union's discriminator; only "ddl" is accepted.
    action_type: Literal["ddl"] = Field(
        ...,
        description="The type of action, must be 'ddl'",
    )
    # DDL statement text to execute (required).
    sql: str = Field(
        ...,
        description="The DDL SQL statement to execute",
    )
class TestAction(BaseModel):
    # Action asking the environment to run its tests against one table.
    # Strict mode: pydantic rejects inputs of the wrong type instead of coercing them.
    model_config = ConfigDict(strict=True)

    # Tag read by the ``Action`` union's discriminator; only "test" is accepted.
    action_type: Literal["test"] = Field(
        ...,
        description="The type of action, must be 'test'",
    )
    # Name of the table the tests should target (required).
    target_table: str = Field(
        ...,
        description="The target table to run tests against",
    )
class SubmitAction(BaseModel):
    # Action with no payload beyond its tag; only the discriminator field is carried.
    # Strict mode: pydantic rejects inputs of the wrong type instead of coercing them.
    model_config = ConfigDict(strict=True)

    # Tag read by the ``Action`` union's discriminator; only "submit" is accepted.
    action_type: Literal["submit"] = Field(
        ...,
        description="The type of action, must be 'submit'",
    )
# Discriminated union of the four action models. Pydantic selects the concrete
# model by looking at the ``action_type`` value of the incoming data.
Action = Annotated[
    Union[
        QueryAction,
        DDLAction,
        TestAction,
        SubmitAction,
    ],
    Field(
        discriminator='action_type',
        description="Union of all four actions, discriminated by action_type",
    ),
]
class Observation(BaseModel):
    # Snapshot of what the agent sees after a step.
    # Strict mode: pydantic rejects inputs of the wrong type instead of coercing them.
    model_config = ConfigDict(strict=True)

    # --- episode progress ---
    current_step: int = Field(..., description="The current step in the episode")
    max_steps: int = Field(..., description="The maximum number of steps allowed in the episode")

    # --- task identity ---
    task_id: int = Field(..., description="The unique identifier for the current task")
    task_description: str = Field(..., description="The description of the task")

    # --- outcome of the previous action ---
    # Restricted to exactly these three strings.
    # NOTE(review): "NONE" presumably means no action has run yet; confirm with the env.
    last_action_status: Literal["SUCCESS", "ERROR", "NONE"] = Field(
        ...,
        description="The status of the last executed action",
    )
    last_error_message: Optional[str] = Field(
        default=None,
        description="The error message from the last action, if any",
    )

    # --- query output ---
    # The "up to 10 rows" limit is NOT enforced here (no max_length); whoever
    # builds the observation must truncate and set the two fields below.
    query_results: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Up to 10 rows from the last query result",
    )
    results_truncated: bool = Field(
        default=False,
        description="True if query returned more rows than shown",
    )
    total_rows_returned: int = Field(
        default=0,
        description="Actual row count before truncation",
    )

    # --- environment context ---
    schema_info: Dict[str, Any] = Field(..., description="Column names and types only — not data")
    # Validation fails if more than 20 log lines are supplied.
    system_logs: List[str] = Field(..., max_length=20, description="A list of system logs")
    logs_truncated: bool = Field(
        default=False,
        description="True if there were more logs than shown",
    )
    progress_hint: Optional[str] = Field(
        default=None,
        description="A hint for the progress of the task, if available",
    )
class Reward(BaseModel):
    # Per-step reward signal plus episode-termination flags and grader scores.
    # Strict mode: pydantic rejects inputs of the wrong type instead of coercing them.
    model_config = ConfigDict(strict=True)

    # Validation enforces the closed interval [-1.0, 1.0].
    step_reward: float = Field(
        ...,
        ge=-1.0,
        le=1.0,
        description="The reward for the current step, between -1.0 and 1.0",
    )
    # Running total; no bounds are enforced.
    cumulative_reward: float = Field(
        ...,
        description="The total cumulative reward accumulated so far",
    )
    # Presumably component name -> contribution; verify against the reward producer.
    reward_breakdown: Dict[str, float] = Field(
        ...,
        description="A breakdown of the components contributing to the reward",
    )

    # NOTE(review): the wording "finished successfully or not" is ambiguous; it
    # likely means "finished, whether or not it succeeded". Confirm before relying on it.
    done: bool = Field(
        ...,
        description="Whether the episode has finished successfully or not",
    )
    truncated: bool = Field(
        ...,
        description="Whether the episode was truncated (e.g., maximum steps reached)",
    )

    # Grader scores around the action; no range is enforced here.
    grader_score_before: float = Field(..., description="Grader score before the action")
    grader_score_after: float = Field(..., description="Grader score after the action")
class StateSnapshot(BaseModel):
    # Full environment state: table contents, trajectory, grader score and seed.
    # Strict mode: pydantic rejects inputs of the wrong type instead of coercing them.
    model_config = ConfigDict(strict=True)

    # --- identity ---
    episode_id: str = Field(..., description="The unique identifier of the episode")
    task_id: int = Field(..., description="The unique identifier of the task")
    current_step: int = Field(..., description="The current step count")

    # --- data ---
    # Presumably table name -> list of row dicts; confirm against the producer.
    tables: Dict[str, List[Dict[str, Any]]] = Field(
        ...,
        description="The contents of the tables currently in the environment",
    )
    trajectory: List[Dict[str, Any]] = Field(
        ...,
        description="The trajectory of actions and observations collected so far",
    )

    # --- scoring / reproducibility ---
    grader_score: float = Field(..., description="The current score assigned by the grader")
    seed: int = Field(..., description="The random seed used for the current state snapshot")
    # Optional; defaults to 1.0 when not supplied.
    difficulty_multiplier: float = Field(
        default=1.0,
        description="Task difficulty curriculum multiplier",
    )
class CompletedEpisode(BaseModel):
    # Summary record of an episode: identity, step count, score and failed actions.
    # Every field is required (``...``), exactly as with the previous bare annotations;
    # ``Field`` is used only so the JSON schema carries descriptions like the
    # other models in this module.
    # Strict mode: pydantic rejects inputs of the wrong type instead of coercing them.
    model_config = ConfigDict(strict=True)

    episode_id: str = Field(..., description="The unique identifier of the episode")
    task_id: int = Field(..., description="The unique identifier of the task")
    total_steps: int = Field(..., description="The total number of steps taken in the episode")
    # NOTE(review): presumably the grader score at episode end; confirm with the producer.
    final_score: float = Field(..., description="The final score of the episode")
    # NOTE(review): element format (e.g. raw SQL vs. serialized action) is not fixed here.
    failed_actions: List[str] = Field(..., description="The actions that failed during the episode")
|