| """ | |
| SQLEnv Pydantic models: the typed contract between the environment and agent. | |
| These models define the typed interface for the SQLEnv RL environment: | |
| Action: what the agent sends each step | |
| Observation: what the agent receives back | |
| State: episode metadata (lightweight logging/debugging view) | |
| RL terminology: state vs observation | |
| ------------------------------------ | |
| In RL theory: | |
| State (s) A COMPLETE description of the world. Nothing is hidden. | |
| Observation (o) A PARTIAL description of a state, which may omit info. | |
| In SQLEnv these map to: | |
| EpisodeContext The full RL state (s). Lives on the server only. | |
| Contains gold answers, reward accumulators, DB | |
| connection, full query history: everything needed | |
| to advance the simulation and compute rewards. | |
| SQLObservation The observation (o). Sent to the agent over the wire. | |
| Contains the question, truncated results, revealed | |
| schema, budget, and action history. The agent NEVER | |
| sees the gold answer, progress scores, or full DB. | |
| SQLState Lightweight episode metadata (episode_id, step_count). | |
| This is NOT the RL state; it is a convenience for | |
| logging/debugging. | |
| This separation is what makes SQLEnv a POMDP: the agent must act under | |
| uncertainty, which is what makes exploration necessary and learnable. | |
| """ | |
| import sqlite3 | |
| from dataclasses import dataclass, field as dataclass_field | |
| from pydantic import BaseModel, Field | |
| # --------------------------------------------------------------------------- | |
| # Wire types: the typed contract between the environment and the agent. | |
| # | |
| # These were originally OpenEnv Action/Observation/State subclasses. OpenEnv | |
| # has been removed (training runs the environment in-process via TRL), so they | |
| # are now plain Pydantic models that re-declare the few fields the base classes | |
| # used to provide: done/reward on the observation, episode_id/step_count on the | |
| # state. | |
| # --------------------------------------------------------------------------- | |
| class SQLAction(BaseModel): | |
| """What the agent sends each step. | |
| The action space is intentionally small and structured so agents can | |
| explicitly control the environment loop. | |
| """ | |
| action_type: str = Field( | |
| ..., | |
| description="One of: DESCRIBE, SAMPLE, QUERY, ANSWER", | |
| ) | |
| argument: str = Field( | |
| ..., | |
| description=( | |
| "Table name (DESCRIBE/SAMPLE), SQL string (QUERY), " | |
| "or answer value (ANSWER)." | |
| ), | |
| ) | |
| class SQLObservation(BaseModel): | |
| """What the agent receives after each step. | |
| This is the agent's PARTIAL view of the world. Key design choices: | |
| - schema_info starts with table names only; columns are revealed | |
| incrementally as the agent DESCRIBEs tables. | |
| - result is always a truncated string, never raw data. The agent sees | |
| what a human analyst would see in a terminal: at most N rows of | |
| formatted text. This keeps the observation bounded and forces the | |
| agent to reason about what it sees rather than brute-force scanning. | |
| - action_history gives the agent memory of its own trajectory without | |
| the server needing to re-send full results from prior steps. | |
| """ | |
| # Formerly inherited from OpenEnv's Observation base class: | |
| done: bool = Field(default=False, description="Whether the episode has ended") | |
| reward: float | None = Field( | |
| default=None, description="Reward for the last step (None if not scored)" | |
| ) | |
| question: str = Field(..., description="The NL question to answer") | |
| schema_info: str = Field(..., description="Known schema information") | |
| result: str = Field(default="", description="Result of the last action") | |
| error: str = Field(default="", description="Error message if action failed") | |
| step_count: int = Field(default=0, description="Current step number") | |
| budget_remaining: int = Field(default=0, description="Steps remaining") | |
| action_history: list[str] = Field( | |
| default_factory=list, | |
| description="Summary of previous actions", | |
| ) | |
| class SQLState(BaseModel): | |
| """Episode metadata: minimal public state for logging and debugging. | |
| This is NOT the full internal bookkeeping (see EpisodeContext below). | |
| """ | |
| # Formerly inherited from OpenEnv's State base class: | |
| episode_id: str | None = Field(default=None, description="Episode identifier") | |
| step_count: int = Field(default=0, description="Current step number") | |
| history_messages: list[dict[str, str]] = Field(default_factory=list) | |
| current_action_type: str = Field( | |
| default="QUERY", | |
| description="Current action type: DESCRIBE, SAMPLE, QUERY, or ANSWER", | |
| ) | |
| @dataclass | |
| class QuestionRecord: | |
| """One question from the Spider dataset.""" | |
| question_id: str | |
| question_text: str | |
| database_name: str | |
| gold_sql: str | |
| gold_answer: str | |
| answer_type: str | |
| difficulty: str | |
| tables_involved: list[str] | |
| @dataclass | |
| class EpisodeContext: | |
| """Per-episode server-side state (never sent to agent).""" | |
| episode_id: str | |
| db_connection: sqlite3.Connection | |
| question_record: QuestionRecord | |
| step_count: int = 0 | |
| budget: int = 15 | |
| described_tables: set[str] = dataclass_field(default_factory=set) | |
| action_log: list[str] = dataclass_field(default_factory=list) | |
| done: bool = False | |
| gold_answer: str | None = None | |
| gold_rows: list[tuple] = dataclass_field(default_factory=list) | |
| query_hashes: set[str] = dataclass_field(default_factory=set) | |
| previous_progress: float = 0.0 | |
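
The `SQLObservation` docstring says `result` is always a truncated string of at most N rows, and `SQLAction.action_type` is documented as one of four values. The sketch below shows one way a server could enforce both, using only the standard library. `ALLOWED_ACTION_TYPES`, `MAX_ROWS`, `is_valid_action_type`, and `format_result` are hypothetical names for illustration and are not part of this module; the row limit of 3 and the `singer` table are likewise assumptions.

```python
import sqlite3

# Hypothetical constants: not defined in the source module.
ALLOWED_ACTION_TYPES = ("DESCRIBE", "SAMPLE", "QUERY", "ANSWER")
MAX_ROWS = 3  # assumed value for the "at most N rows" limit


def is_valid_action_type(action_type: str) -> bool:
    """Return True if action_type is one of the four documented values."""
    return action_type in ALLOWED_ACTION_TYPES


def format_result(rows: list[tuple], max_rows: int = MAX_ROWS) -> str:
    """Render rows as bounded text and note how many rows were cut."""
    shown = rows[:max_rows]
    lines = [" | ".join(str(value) for value in row) for row in shown]
    hidden = len(rows) - len(shown)
    if hidden > 0:
        lines.append(f"... ({hidden} more rows)")
    return "\n".join(lines)


# Usage: query a throwaway in-memory database and truncate the output.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE singer (id INTEGER, name TEXT)")
conn.executemany(
    "INSERT INTO singer VALUES (?, ?)",
    [(i, f"singer_{i}") for i in range(1, 6)],
)
rows = conn.execute("SELECT id, name FROM singer ORDER BY id").fetchall()
result = format_result(rows)
conn.close()
print(result)
```

Returning a plain string rather than the raw rows keeps the observation size bounded regardless of how many rows the query matches, which is the property the docstring describes.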