# SqlBenchmarking / models.py
# Uploaded by hitanshjain1812.
# Last commit: "Upload folder using huggingface_hub" (7a0f237, verified).
"""
Pydantic models for the SQL Data Analyst OpenEnv environment.
Defines the Action, Observation, State, StepResult, and ResetResult typed models.
"""
from __future__ import annotations
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Action
# ---------------------------------------------------------------------------
class SQLAction(BaseModel):
    """A single move an agent makes in the SQL Data Analyst environment.

    ``action_type`` is mandatory and selects the operation. ``sql_query`` and
    ``answer`` carry the payload for that operation and both default to
    ``None``. This model declares no validators, so it does not itself check
    that the payload matching the chosen action type is present.
    """

    # Closed set of operations the agent may request.
    action_type: Literal[
        "execute_query",   # run a SQL query against the episode database
        "describe_table",  # fetch schema plus sample rows for one table
        "submit_answer",   # hand the final answer to the grader
        "list_tables",     # enumerate the tables available this episode
        "noop",            # do nothing; still consumes a step
    ] = Field(..., description="Type of action to perform.")

    # Payload for execute_query / describe_table.
    sql_query: Optional[str] = Field(
        None,
        description="SQL query string (required for execute_query and describe_table).",
    )

    # Payload for submit_answer; its shape is task-specific.
    answer: Optional[Dict[str, Any]] = Field(
        None,
        description=(
            "Final answer dict submitted to the grader (required for submit_answer). "
            "Schema depends on the active task."
        ),
    )
# ---------------------------------------------------------------------------
# Observation
# ---------------------------------------------------------------------------
class SQLObservation(BaseModel):
    """What the environment reports back to the agent after each step.

    Task context and the step counters are mandatory. The ``last_*`` fields
    and ``hints`` are optional and default to ``None``.
    """

    # --- Task context ------------------------------------------------------
    task_id: str = Field(..., description="Identifier of the active task.")
    goal: str = Field(
        ...,
        description="Natural-language description of what the agent must accomplish.",
    )
    schema_info: str = Field(
        ...,
        description="DDL / schema description of the available tables.",
    )
    data_sample: List[Dict[str, Any]] = Field(
        ...,
        description="Up to 5 sample rows from the primary table, for orientation.",
    )

    # --- Outcome of the most recent action ---------------------------------
    last_query_result: Optional[List[Dict[str, Any]]] = Field(
        None,
        description="Rows returned by the most recent execute_query action (None if no query yet).",
    )
    last_query_error: Optional[str] = Field(
        None,
        description="Error message if the last SQL query failed, otherwise None.",
    )
    last_action_error: Optional[str] = Field(
        None,
        description="Error from the last action (malformed action, etc.), otherwise None.",
    )

    # --- Step budget ---------------------------------------------------------
    step_count: int = Field(
        ...,
        description="Number of steps taken so far in this episode.",
    )
    max_steps: int = Field(
        ...,
        description="Maximum steps allowed before episode terminates.",
    )

    # --- Optional guidance -------------------------------------------------
    hints: Optional[List[str]] = Field(
        None,
        description="Optional hints unlocked as steps progress.",
    )
# ---------------------------------------------------------------------------
# State
# ---------------------------------------------------------------------------
class SQLState(BaseModel):
    """Bookkeeping metadata describing one episode.

    Every field is mandatory. The ``[0.0, 1.0]`` range on ``current_score``
    is stated in its description only; this model does not enforce it.
    """

    # Identity of the episode and the task it runs.
    episode_id: str = Field(..., description="Unique identifier for this episode.")
    task_id: str = Field(..., description="Active task identifier.")

    # Progress against the step budget.
    step_count: int = Field(..., description="Number of steps taken so far.")
    max_steps: int = Field(..., description="Maximum steps for this episode.")

    # Scoring and termination.
    current_score: float = Field(..., description="Running score in [0.0, 1.0].")
    done: bool = Field(..., description="Whether the episode has ended.")
# ---------------------------------------------------------------------------
# StepResult (returned by /step endpoint)
# ---------------------------------------------------------------------------
class StepResult(BaseModel):
    """Payload the /step endpoint sends back.

    Bundles the new observation with the step's reward, the termination flag,
    and a free-form diagnostics dict that defaults to empty.
    """

    observation: SQLObservation
    reward: float = Field(..., description="Reward for the current step.")
    done: bool = Field(..., description="True if the episode has ended.")
    # default_factory gives each instance its own fresh dict.
    info: Dict[str, Any] = Field(
        default_factory=dict,
        description="Extra diagnostic info.",
    )
# ---------------------------------------------------------------------------
# ResetResult (returned by /reset endpoint)
# ---------------------------------------------------------------------------
class ResetResult(BaseModel):
    """Result returned by the /reset endpoint.

    Carries the observation for the episode plus a termination flag and a
    free-form diagnostics dict. ``done`` defaults to ``False`` and ``info``
    defaults to an empty dict, so only ``observation`` must be supplied.
    """

    observation: SQLObservation
    # Field descriptions mirror StepResult so both endpoint payloads document
    # the shared ``done`` / ``info`` fields the same way in the JSON schema.
    done: bool = Field(default=False, description="True if the episode has ended.")
    # default_factory gives each instance its own fresh dict.
    info: Dict[str, Any] = Field(
        default_factory=dict,
        description="Extra diagnostic info.",
    )