Spaces:
Sleeping
Sleeping
"""
models.py — Typed Pydantic models for the SQL Repair Clinic OpenEnv environment.
All action, observation, reward, and state models are defined here.
"""
| from __future__ import annotations | |
| from typing import Any, Dict, List, Optional | |
| from pydantic import BaseModel, Field | |
# ─────────────────────────────────────────────
# Action
# ─────────────────────────────────────────────
class SQLAction(BaseModel):
    """The single action an agent can take: submit a SQL query."""

    # The only field an agent supplies. No validation beyond "is a string" is
    # applied here (e.g. an empty string is accepted by this model).
    query: str = Field(
        ...,
        description="A SQL query string to execute against the environment database.",
    )
# ─────────────────────────────────────────────
# Observation
# ─────────────────────────────────────────────
class SQLObservation(BaseModel):
    """Full observation returned after reset() or step()."""

    # Task definition.
    task_name: str = Field(..., description="Identifier of the active task.")
    # Free-form string: the three values named in the description are not
    # enforced by this model (no Literal/enum constraint).
    difficulty: str = Field(..., description="easy | medium | hard")
    task_description: str = Field(..., description="Natural-language description of what the agent must achieve.")
    schema_info: str = Field(..., description="DDL + sample rows describing the database schema.")
    initial_broken_query: str = Field(..., description="The broken/incomplete SQL query the agent starts with.")

    # Outcome of the most recent submission.
    last_submitted_query: str = Field(..., description="Most recently submitted query (same as initial on reset).")
    error_message: Optional[str] = Field(None, description="Execution error from the last submitted query, if any.")
    # The 5-row cap is the producer's responsibility; this model does not
    # enforce a maximum list length.
    result_preview: Optional[List[Dict[str, Any]]] = Field(
        None, description="Up to 5 rows returned by the last query (None if query errored)."
    )

    # Episode progress.
    step_count: int = Field(..., description="Number of steps taken so far in this episode.")
    max_steps: int = Field(..., description="Maximum allowed steps before episode ends.")
    # NOTE(review): unlike StepResponse.reward / SQLReward.value, this field
    # is not bounded to [0.0, 1.0] by the model.
    last_reward: float = Field(..., description="Reward from the most recent step (0.0 on reset).")
    # Whether/when a hint is populated is decided by the environment, not here.
    hint: Optional[str] = Field(None, description="Optional hint shown after 3+ failed attempts.")
# ─────────────────────────────────────────────
# Reward
# ─────────────────────────────────────────────
class SQLReward(BaseModel):
    """Structured reward with explanation."""

    # Bounds are enforced at validation time: values outside [0.0, 1.0]
    # are rejected with a pydantic ValidationError.
    value: float = Field(..., ge=0.0, le=1.0, description="Numeric reward in [0.0, 1.0].")
    reason: str = Field(..., description="Human-readable explanation of why this reward was given.")
# ─────────────────────────────────────────────
# Step Response
# ─────────────────────────────────────────────
class StepResponse(BaseModel):
    """Complete response from POST /step."""

    observation: SQLObservation = Field(..., description="Observation produced by this step.")
    # A plain float rather than an SQLReward; the bounds mirror SQLReward.value
    # and are enforced at validation time.
    reward: float = Field(..., ge=0.0, le=1.0, description="Reward for this step in [0.0, 1.0].")
    done: bool = Field(..., description="Whether the episode has ended.")
    # Required, untyped payload: callers must always supply a dict (may be empty).
    info: Dict[str, Any] = Field(..., description="Free-form auxiliary information about the step.")
# ─────────────────────────────────────────────
# State
# ─────────────────────────────────────────────
class EnvironmentState(BaseModel):
    """Lightweight state snapshot returned by GET /state."""

    # All fields are required; descriptions mirror SQLObservation where the
    # fields overlap.
    task_name: str = Field(..., description="Identifier of the active task.")
    difficulty: str = Field(..., description="easy | medium | hard")
    step_count: int = Field(..., description="Number of steps taken so far in this episode.")
    max_steps: int = Field(..., description="Maximum allowed steps before episode ends.")
    done: bool = Field(..., description="Whether the episode has ended.")
    last_reward: float = Field(..., description="Reward from the most recent step.")
    last_submitted_query: str = Field(..., description="Most recently submitted query.")
    session_id: str = Field(..., description="Identifier of the session this state belongs to.")
# ─────────────────────────────────────────────
# Reset Request
# ─────────────────────────────────────────────
class ResetRequest(BaseModel):
    """Optional body for POST /reset."""

    # Any string is accepted by this model; the task names listed in the
    # description are not enforced here and must be validated by the consumer.
    task: str = Field(
        default="fix_syntax",
        description="Task to load: fix_syntax | fix_logic | write_analytical",
    )
    session_id: Optional[str] = Field(None, description="Optional session identifier.")