"""
Typed Pydantic models for the SQL Debug Environment.
Implements the OpenEnv spec: Observation, Action, Reward.
"""
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
from enum import Enum
class ActionType(str, Enum):
    """Enumerates every action the agent may take in a debugging episode.

    Members (values are the wire-format strings):
        SUBMIT_QUERY:   submit a candidate fixed SQL query for evaluation.
        INSPECT_SCHEMA: request table schema info (free action, gives info).
        INSPECT_ERROR:  request error details / stack trace (free action).
        INSPECT_SAMPLE: request 3 sample rows from a table.
        RESET_QUERY:    revert to the original broken query (-0.05 penalty).
    """

    SUBMIT_QUERY = "submit_query"
    INSPECT_SCHEMA = "inspect_schema"
    INSPECT_ERROR = "inspect_error"
    INSPECT_SAMPLE = "inspect_sample"
    RESET_QUERY = "reset_query"
class SQLDebugAction(BaseModel):
    """A single agent action in the SQL Debug Environment.

    Exactly one optional payload field is relevant per action type:
    ``query`` must accompany ``submit_query`` and ``table_name`` must
    accompany ``inspect_sample``. The inspect_* actions are free (no reward
    change); ``reset_query`` restores the original broken query at a -0.05
    penalty.
    """

    # What the agent wants to do on this step.
    action_type: ActionType = Field(description="Type of action to take")

    # Payload for submit_query; ignored by every other action type.
    query: Optional[str] = Field(
        description="SQL query string. Required when action_type is 'submit_query'.",
        default=None,
    )

    # Payload for inspect_sample; ignored by every other action type.
    table_name: Optional[str] = Field(
        description="Table name. Required when action_type is 'inspect_sample'.",
        default=None,
    )

    class Config:
        # Example surfaced in the generated JSON schema / OpenAPI docs.
        json_schema_extra = {
            "example": {
                "action_type": "submit_query",
                "query": "SELECT u.name, COUNT(o.id) as order_count FROM users u LEFT JOIN orders o ON u.id = o.user_id GROUP BY u.id, u.name ORDER BY order_count DESC",
            }
        }
class QueryResult(BaseModel):
    """Outcome of executing one SQL query against the task database."""

    success: bool  # True when the query executed without error
    rows: Optional[List[Dict[str, Any]]] = None  # result rows as column -> value mappings
    row_count: Optional[int] = None  # number of rows in the result
    error_message: Optional[str] = None  # database error text, when success is False
    execution_time_ms: Optional[float] = None  # execution time in milliseconds
class SchemaInfo(BaseModel):
    """Structural description of the task database."""

    # Maps each table name to its column descriptors ({name, type, nullable}).
    tables: Dict[str, List[Dict[str, str]]]
    # Optional sample rows per table, keyed by table name.
    sample_data: Optional[Dict[str, List[Dict[str, Any]]]] = None
class SQLDebugObservation(BaseModel):
    """
    Observation returned after each step.

    Contains the current state of the debugging session:
    - The original broken query (always visible)
    - The agent's current best query
    - Result of last action
    - Progress indicators
    - Schema/error info if requested
    """
    task_id: str = Field(description="Current task identifier")
    task_description: str = Field(description="Natural language description of the bug to fix")
    original_query: str = Field(description="The original broken SQL query")
    current_query: Optional[str] = Field(default=None, description="Agent's last submitted query")
    expected_description: str = Field(description="Description of what the correct output should look like")

    # Last action result
    last_action_type: str  # presumably an ActionType value string — TODO confirm against env
    last_query_result: Optional[QueryResult] = None  # presumably set only after submit_query — verify

    # Progress
    steps_taken: int  # steps consumed so far in this episode
    steps_remaining: int  # steps left before the episode is forced to end
    current_score: float = Field(description="Current score in strict range (0, 1) for this episode")

    # Contextual help (populated based on action type)
    schema_info: Optional[SchemaInfo] = None  # populated after inspect_schema
    error_details: Optional[str] = None  # populated after inspect_error
    sample_rows: Optional[List[Dict[str, Any]]] = None  # populated after inspect_sample

    # Hints (unlocked after step 3 on easy, step 5 on medium/hard)
    hint: Optional[str] = None

    # Episode status
    is_done: bool = False  # True once the episode has terminated
    success: bool = False  # True when the task is solved
class SQLDebugReward(BaseModel):
    """
    Reward signal for the SQL Debug Environment.

    Reward components (all sum to final reward):
    - correctness: 0.0-0.6 based on row-level match vs expected output
    - efficiency: 0.0-0.2 bonus for solving in fewer steps
    - syntax_progress: 0.0-0.1 for getting a syntactically valid query (even if wrong)
    - schema_bonus: 0.0-0.1 for queries that reference correct tables/columns
    - penalties: negative values for reset_query, infinite loops, destructive SQL

    NOTE(review): the positive components can sum to 1.0 and the penalty can
    push the raw sum below 0, yet ``value`` is constrained to [0.001, 0.999].
    Presumably the environment clamps the raw sum into this strict (0, 1)
    band before constructing the model — confirm against the reward code.
    """
    # Total reward for the step, strictly inside (0, 1).
    value: float = Field(ge=0.001, le=0.999, description="Total reward for this step")
    correctness: float = Field(ge=0.0, le=0.6)  # row-level match vs expected output
    efficiency: float = Field(ge=0.0, le=0.2)  # fewer-steps bonus
    syntax_progress: float = Field(ge=0.0, le=0.1)  # syntactically valid query
    schema_bonus: float = Field(ge=0.0, le=0.1)  # correct tables/columns referenced
    penalty: float = Field(ge=0.0, le=0.2, description="Penalty deduction magnitude (non-negative)")
    breakdown: str = Field(description="Human-readable reward breakdown")
class EpisodeState(BaseModel):
    """Full internal state of an episode. Used by state() endpoint."""
    task_id: str  # identifier of the task being debugged
    task_difficulty: str  # presumably "easy"/"medium"/"hard" (see hint docs above) — TODO confirm
    original_query: str  # the broken SQL the episode started with
    current_query: Optional[str]  # agent's latest submitted query, if any
    best_score_so_far: float  # highest score reached this episode
    steps_taken: int  # actions consumed so far
    max_steps: int  # step budget for the episode
    action_history: List[Dict[str, Any]]  # serialized record of each action taken
    reward_history: List[float]  # per-step reward values
    is_done: bool  # True once the episode has terminated
    success: bool  # True when the task was solved
    db_schema: Dict[str, Any]  # schema of the episode's database