NEWHACKTHONSPACE / models.py
AKGW580's picture
first commit
89208c7
"""
models.py — Typed Pydantic models for the SQL Repair Clinic OpenEnv environment.
All action, observation, reward, and state models are defined here.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
# ─────────────────────────────────────────────
# Action
# ─────────────────────────────────────────────
class SQLAction(BaseModel):
    """The single action an agent can take: submit a SQL query."""

    # Required field (Ellipsis default): an action cannot be constructed
    # without a query string.
    query: str = Field(
        default=...,
        description="A SQL query string to execute against the environment database.",
    )
# ─────────────────────────────────────────────
# Observation
# ─────────────────────────────────────────────
class SQLObservation(BaseModel):
    """Full observation returned after reset() or step()."""

    # --- Task definition (all required) ---
    task_name: str = Field(
        default=...,
        description="Identifier of the active task.",
    )
    difficulty: str = Field(
        default=...,
        description="easy | medium | hard",
    )
    task_description: str = Field(
        default=...,
        description="Natural-language description of what the agent must achieve.",
    )
    schema_info: str = Field(
        default=...,
        description="DDL + sample rows describing the database schema.",
    )

    # --- Queries: the starting point and the latest submission (required) ---
    initial_broken_query: str = Field(
        default=...,
        description="The broken/incomplete SQL query the agent starts with.",
    )
    last_submitted_query: str = Field(
        default=...,
        description="Most recently submitted query (same as initial on reset).",
    )

    # --- Outcome of the last query (both optional, default None) ---
    error_message: Optional[str] = Field(
        default=None,
        description="Execution error from the last submitted query, if any.",
    )
    result_preview: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="Up to 5 rows returned by the last query (None if query errored).",
    )

    # --- Episode progress (required) ---
    step_count: int = Field(
        default=...,
        description="Number of steps taken so far in this episode.",
    )
    max_steps: int = Field(
        default=...,
        description="Maximum allowed steps before episode ends.",
    )
    last_reward: float = Field(
        default=...,
        description="Reward from the most recent step (0.0 on reset).",
    )

    # --- Optional assistance (default None) ---
    hint: Optional[str] = Field(
        default=None,
        description="Optional hint shown after 3+ failed attempts.",
    )
# ─────────────────────────────────────────────
# Reward
# ─────────────────────────────────────────────
class SQLReward(BaseModel):
    """Structured reward with explanation."""

    # `ge`/`le` bound the value to the closed interval [0.0, 1.0]; values
    # outside it fail validation.
    value: float = Field(
        default=...,
        description="Numeric reward in [0.0, 1.0].",
        ge=0.0,
        le=1.0,
    )
    # Required free-text justification accompanying the numeric value.
    reason: str = Field(
        default=...,
        description="Human-readable explanation of why this reward was given.",
    )
# ─────────────────────────────────────────────
# Step Response
# ─────────────────────────────────────────────
class StepResponse(BaseModel):
    """Complete response from POST /step."""

    # Every field stays required and keeps its original type and bounds; the
    # only change is that each now carries a schema description, matching the
    # other models in this module.
    observation: SQLObservation = Field(
        ...,
        description="Observation returned after this step.",
    )
    # Same closed [0.0, 1.0] bound as before; out-of-range values fail validation.
    reward: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Numeric reward in [0.0, 1.0].",
    )
    done: bool = Field(
        ...,
        description="Whether the episode has ended.",
    )
    # Free-form mapping; its keys are not constrained by this model.
    info: Dict[str, Any] = Field(
        ...,
        description="Additional key/value information about this step.",
    )
# ─────────────────────────────────────────────
# State
# ─────────────────────────────────────────────
class EnvironmentState(BaseModel):
    """Lightweight state snapshot returned by GET /state."""

    # All fields remain required with their original types and no added
    # validation constraints. Descriptions are new and reuse the wording of
    # the same-named SQLObservation fields so the two schemas read alike.
    task_name: str = Field(..., description="Identifier of the active task.")
    difficulty: str = Field(..., description="easy | medium | hard")
    step_count: int = Field(
        ...,
        description="Number of steps taken so far in this episode.",
    )
    max_steps: int = Field(
        ...,
        description="Maximum allowed steps before episode ends.",
    )
    done: bool = Field(..., description="Whether the episode has ended.")
    last_reward: float = Field(
        ...,
        description="Reward from the most recent step.",
    )
    last_submitted_query: str = Field(
        ...,
        description="Most recently submitted query.",
    )
    session_id: str = Field(..., description="Session identifier.")
# ─────────────────────────────────────────────
# Reset Request
# ─────────────────────────────────────────────
class ResetRequest(BaseModel):
    """Optional body for POST /reset."""

    # Falls back to "fix_syntax" when the request does not name a task.
    task: str = Field(
        "fix_syntax",
        description="Task to load: fix_syntax | fix_logic | write_analytical",
    )
    # May be omitted from the request, in which case it is None.
    session_id: Optional[str] = Field(
        default=None,
        description="Optional session identifier.",
    )