""" models.py — Typed Pydantic models for SQL Repair Clinic OpenEnv environment. All action, observation, reward, and state models are defined here. """ from __future__ import annotations from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field # ───────────────────────────────────────────── # Action # ───────────────────────────────────────────── class SQLAction(BaseModel): """The single action an agent can take: submit a SQL query.""" query: str = Field(..., description="A SQL query string to execute against the environment database.") # ───────────────────────────────────────────── # Observation # ───────────────────────────────────────────── class SQLObservation(BaseModel): """Full observation returned after reset() or step().""" task_name: str = Field(..., description="Identifier of the active task.") difficulty: str = Field(..., description="easy | medium | hard") task_description: str = Field(..., description="Natural-language description of what the agent must achieve.") schema_info: str = Field(..., description="DDL + sample rows describing the database schema.") initial_broken_query: str = Field(..., description="The broken/incomplete SQL query the agent starts with.") last_submitted_query: str = Field(..., description="Most recently submitted query (same as initial on reset).") error_message: Optional[str] = Field(None, description="Execution error from the last submitted query, if any.") result_preview: Optional[List[Dict[str, Any]]] = Field( None, description="Up to 5 rows returned by the last query (None if query errored)." ) step_count: int = Field(..., description="Number of steps taken so far in this episode.") max_steps: int = Field(..., description="Maximum allowed steps before episode ends.") last_reward: float = Field(..., description="Reward from the most recent step (0.0 on reset).") hint: Optional[str] = Field(None, description="Optional hint shown after 3+ failed attempts.") # ───────────────────────────────────────────── # Reward # ───────────────────────────────────────────── class SQLReward(BaseModel): """Structured reward with explanation.""" value: float = Field(..., ge=0.0, le=1.0, description="Numeric reward in [0.0, 1.0].") reason: str = Field(..., description="Human-readable explanation of why this reward was given.") # ───────────────────────────────────────────── # Step Response # ───────────────────────────────────────────── class StepResponse(BaseModel): """Complete response from POST /step.""" observation: SQLObservation reward: float = Field(..., ge=0.0, le=1.0) done: bool info: Dict[str, Any] # ───────────────────────────────────────────── # State # ───────────────────────────────────────────── class EnvironmentState(BaseModel): """Lightweight state snapshot returned by GET /state.""" task_name: str difficulty: str step_count: int max_steps: int done: bool last_reward: float last_submitted_query: str session_id: str # ───────────────────────────────────────────── # Reset Request # ───────────────────────────────────────────── class ResetRequest(BaseModel): """Optional body for POST /reset.""" task: str = Field( default="fix_syntax", description="Task to load: fix_syntax | fix_logic | write_analytical" ) session_id: Optional[str] = Field(None, description="Optional session identifier.")