Spaces:
Sleeping
Sleeping
Commit ·
8cb206e
1
Parent(s): 1eef47f
Round 2: SQL Database Engineer Agent - 24/24 tests passing
Browse files- api/server.py +112 -58
- blog/mini_blog.md +0 -0
- dataset/easy_scenarios.json +92 -0
- dataset/hard_scenarios.json +185 -0
- dataset/medium_scenarios.json +137 -0
- env/__pycache__/models.cpython-312.pyc +0 -0
- env/environment.py +195 -130
- env/models.py +93 -58
- env/reward.py +203 -94
- env/tasks.py +178 -63
- tests/test_environment.py +3 -4
- tests/test_graders.py +2 -2
- training/evaluate_agent.py +0 -0
- training/generate_training_data.py +0 -0
- training/train_agent.py +0 -0
api/server.py
CHANGED
|
@@ -16,7 +16,7 @@ from env.models import (
|
|
| 16 |
StepResponse, ResetResponse, TaskListResponse,
|
| 17 |
BaselineResponse, BaselineResult,
|
| 18 |
GraderRequest, GraderResponse,
|
| 19 |
-
HealthResponse, TaskInfo
|
| 20 |
)
|
| 21 |
from env.tasks import task_manager, ACTION_SCHEMA
|
| 22 |
from env.graders import grade
|
|
@@ -33,18 +33,21 @@ async def lifespan(app: FastAPI):
|
|
| 33 |
environment.reset(difficulty="easy")
|
| 34 |
yield
|
| 35 |
|
|
|
|
| 36 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
# APP DEFINITION
|
| 38 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 39 |
|
| 40 |
app = FastAPI(
|
| 41 |
-
title = "SQL
|
| 42 |
description = (
|
| 43 |
"An OpenEnv-compliant reinforcement learning environment where AI agents "
|
| 44 |
-
"learn to
|
| 45 |
-
"
|
|
|
|
|
|
|
| 46 |
),
|
| 47 |
-
version = "
|
| 48 |
lifespan = lifespan,
|
| 49 |
docs_url = "/docs",
|
| 50 |
redoc_url = "/redoc",
|
|
@@ -72,12 +75,11 @@ async def global_exception_handler(request: Request, exc: Exception):
|
|
| 72 |
|
| 73 |
|
| 74 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 75 |
-
# FAVICON
|
| 76 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 77 |
|
| 78 |
@app.get("/favicon.ico", include_in_schema=False)
|
| 79 |
async def favicon():
|
| 80 |
-
"""Returns 204 No Content instead of 404 for favicon requests."""
|
| 81 |
return Response(status_code=204)
|
| 82 |
|
| 83 |
|
|
@@ -87,10 +89,10 @@ async def favicon():
|
|
| 87 |
|
| 88 |
@app.get("/health", response_model=HealthResponse, tags=["System"])
|
| 89 |
async def health():
|
| 90 |
-
"""Liveness check. Always returns 200.
|
| 91 |
return HealthResponse(
|
| 92 |
status = "ok",
|
| 93 |
-
version = "
|
| 94 |
uptime = round(time.time() - _startup_time, 2)
|
| 95 |
)
|
| 96 |
|
|
@@ -106,8 +108,8 @@ class ResetBody(BaseModel):
|
|
| 106 |
@app.post("/reset", response_model=Observation, tags=["Environment"])
|
| 107 |
async def reset(body: ResetBody = ResetBody()):
|
| 108 |
"""
|
| 109 |
-
Starts a fresh episode.
|
| 110 |
-
|
| 111 |
"""
|
| 112 |
try:
|
| 113 |
obs = environment.reset(
|
|
@@ -129,8 +131,9 @@ async def reset(body: ResetBody = ResetBody()):
|
|
| 129 |
async def step(action: Action):
|
| 130 |
"""
|
| 131 |
Submits an action to the environment.
|
| 132 |
-
|
| 133 |
-
|
|
|
|
| 134 |
"""
|
| 135 |
try:
|
| 136 |
response = environment.step(action)
|
|
@@ -140,8 +143,8 @@ async def step(action: Action):
|
|
| 140 |
return StepResponse(
|
| 141 |
observation = environment._build_observation(),
|
| 142 |
reward = Reward(
|
| 143 |
-
score =
|
| 144 |
-
breakdown = {"validation_error":
|
| 145 |
feedback = f"Malformed action: {str(e)}"
|
| 146 |
),
|
| 147 |
done = False,
|
|
@@ -157,11 +160,7 @@ async def step(action: Action):
|
|
| 157 |
|
| 158 |
@app.get("/state", response_model=EpisodeState, tags=["Environment"])
|
| 159 |
async def state():
|
| 160 |
-
"""
|
| 161 |
-
Returns full current environment state.
|
| 162 |
-
Works before reset() is called — returns default empty state.
|
| 163 |
-
Always JSON-serializable. Never crashes.
|
| 164 |
-
"""
|
| 165 |
return environment.state()
|
| 166 |
|
| 167 |
|
|
@@ -172,8 +171,8 @@ async def state():
|
|
| 172 |
@app.get("/tasks", response_model=TaskListResponse, tags=["Tasks"])
|
| 173 |
async def tasks():
|
| 174 |
"""
|
| 175 |
-
Lists all
|
| 176 |
-
|
| 177 |
"""
|
| 178 |
all_tasks = task_manager.list_all_tasks()
|
| 179 |
return TaskListResponse(
|
|
@@ -191,8 +190,8 @@ async def tasks():
|
|
| 191 |
async def grader(request: GraderRequest):
|
| 192 |
"""
|
| 193 |
Grades a completed episode action.
|
|
|
|
| 194 |
Returns float score strictly between 0.0 and 1.0 exclusive.
|
| 195 |
-
Never crashes.
|
| 196 |
"""
|
| 197 |
try:
|
| 198 |
if request.action is None:
|
|
@@ -201,14 +200,42 @@ async def grader(request: GraderRequest):
|
|
| 201 |
feedback = "No action provided for grading.",
|
| 202 |
breakdown = {"error": "null_action"}
|
| 203 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
score, breakdown, feedback = grade(request.action, request.task_id)
|
| 205 |
-
# Clamp strictly between 0 and 1 exclusive
|
| 206 |
score = max(0.001, min(0.999, score))
|
| 207 |
-
return GraderResponse(
|
| 208 |
-
|
| 209 |
-
feedback = feedback,
|
| 210 |
-
breakdown = breakdown
|
| 211 |
-
)
|
| 212 |
except Exception as e:
|
| 213 |
return GraderResponse(
|
| 214 |
score = 0.001,
|
|
@@ -216,6 +243,7 @@ async def grader(request: GraderRequest):
|
|
| 216 |
breakdown = {"error": str(e)}
|
| 217 |
)
|
| 218 |
|
|
|
|
| 219 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 220 |
# 7. /baseline — POST
|
| 221 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -223,9 +251,8 @@ async def grader(request: GraderRequest):
|
|
| 223 |
@app.post("/baseline", response_model=BaselineResponse, tags=["Baseline"])
|
| 224 |
async def baseline():
|
| 225 |
"""
|
| 226 |
-
Runs the baseline agent against all
|
| 227 |
-
|
| 228 |
-
Edge case: OPENAI_API_KEY not set — continues with rule-based agent.
|
| 229 |
"""
|
| 230 |
try:
|
| 231 |
import baseline as baseline_module
|
|
@@ -236,45 +263,72 @@ async def baseline():
|
|
| 236 |
return results
|
| 237 |
except asyncio.TimeoutError:
|
| 238 |
return BaselineResponse(
|
| 239 |
-
results=[
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
score = 0.0,
|
| 244 |
-
steps = 0,
|
| 245 |
-
feedback = "Baseline timed out after 55 seconds."
|
| 246 |
-
)
|
| 247 |
-
],
|
| 248 |
average_score=0.0
|
| 249 |
)
|
| 250 |
except Exception as e:
|
| 251 |
return BaselineResponse(
|
| 252 |
-
results=[
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
score = 0.0,
|
| 257 |
-
steps = 0,
|
| 258 |
-
feedback = f"Baseline error: {str(e)}"
|
| 259 |
-
)
|
| 260 |
-
],
|
| 261 |
average_score=0.0
|
| 262 |
)
|
| 263 |
|
| 264 |
|
| 265 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 266 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 268 |
|
| 269 |
@app.get("/", tags=["System"])
|
| 270 |
async def root():
|
| 271 |
return {
|
| 272 |
-
"name": "SQL
|
| 273 |
-
"version": "
|
|
|
|
| 274 |
"docs": "/docs",
|
| 275 |
"health": "/health",
|
| 276 |
-
"endpoints": ["/reset", "/step", "/state", "/tasks", "/grader", "/baseline", "/health"],
|
| 277 |
-
"hackathon": "META x PyTorch x SST OpenEnv Hackathon",
|
| 278 |
-
"domain": "
|
| 279 |
-
"tasks_count":
|
| 280 |
-
|
|
|
|
|
|
|
|
|
| 16 |
StepResponse, ResetResponse, TaskListResponse,
|
| 17 |
BaselineResponse, BaselineResult,
|
| 18 |
GraderRequest, GraderResponse,
|
| 19 |
+
HealthResponse, TaskInfo, ProgressResponse
|
| 20 |
)
|
| 21 |
from env.tasks import task_manager, ACTION_SCHEMA
|
| 22 |
from env.graders import grade
|
|
|
|
| 33 |
environment.reset(difficulty="easy")
|
| 34 |
yield
|
| 35 |
|
| 36 |
+
|
| 37 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
# APP DEFINITION
|
| 39 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 40 |
|
| 41 |
app = FastAPI(
|
| 42 |
+
title = "SQL Database Engineer Agent — OpenEnv Environment",
|
| 43 |
description = (
|
| 44 |
"An OpenEnv-compliant reinforcement learning environment where AI agents "
|
| 45 |
+
"learn to act like senior database engineers. "
|
| 46 |
+
"The agent manages a simulated production database over 50+ steps: "
|
| 47 |
+
"inspecting slow queries, creating indexes, rewriting queries, partitioning tables. "
|
| 48 |
+
"Built for the META x PyTorch x SST OpenEnv Hackathon Finals — April 25-26, Bangalore."
|
| 49 |
),
|
| 50 |
+
version = "2.0.0",
|
| 51 |
lifespan = lifespan,
|
| 52 |
docs_url = "/docs",
|
| 53 |
redoc_url = "/redoc",
|
|
|
|
| 75 |
|
| 76 |
|
| 77 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 78 |
+
# FAVICON
|
| 79 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 80 |
|
| 81 |
@app.get("/favicon.ico", include_in_schema=False)
|
| 82 |
async def favicon():
|
|
|
|
| 83 |
return Response(status_code=204)
|
| 84 |
|
| 85 |
|
|
|
|
| 89 |
|
| 90 |
@app.get("/health", response_model=HealthResponse, tags=["System"])
|
| 91 |
async def health():
|
| 92 |
+
"""Liveness check. Always returns 200."""
|
| 93 |
return HealthResponse(
|
| 94 |
status = "ok",
|
| 95 |
+
version = "2.0.0",
|
| 96 |
uptime = round(time.time() - _startup_time, 2)
|
| 97 |
)
|
| 98 |
|
|
|
|
| 108 |
@app.post("/reset", response_model=Observation, tags=["Environment"])
|
| 109 |
async def reset(body: ResetBody = ResetBody()):
|
| 110 |
"""
|
| 111 |
+
Starts a fresh episode. Initializes DatabaseSimulator.
|
| 112 |
+
Returns the initial Observation with DB state and slow queries.
|
| 113 |
"""
|
| 114 |
try:
|
| 115 |
obs = environment.reset(
|
|
|
|
| 131 |
async def step(action: Action):
|
| 132 |
"""
|
| 133 |
Submits an action to the environment.
|
| 134 |
+
Round 2 actions: inspect_query, create_index, rewrite_query,
|
| 135 |
+
partition_table, analyze_statistics, analyze_indexes, submit_report.
|
| 136 |
+
Returns (observation, reward, done, info) with DB performance delta.
|
| 137 |
"""
|
| 138 |
try:
|
| 139 |
response = environment.step(action)
|
|
|
|
| 143 |
return StepResponse(
|
| 144 |
observation = environment._build_observation(),
|
| 145 |
reward = Reward(
|
| 146 |
+
score = 0.001,
|
| 147 |
+
breakdown = {"validation_error": 0.001},
|
| 148 |
feedback = f"Malformed action: {str(e)}"
|
| 149 |
),
|
| 150 |
done = False,
|
|
|
|
| 160 |
|
| 161 |
@app.get("/state", response_model=EpisodeState, tags=["Environment"])
|
| 162 |
async def state():
|
| 163 |
+
"""Returns full current environment state including performance history."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
return environment.state()
|
| 165 |
|
| 166 |
|
|
|
|
| 171 |
@app.get("/tasks", response_model=TaskListResponse, tags=["Tasks"])
|
| 172 |
async def tasks():
|
| 173 |
"""
|
| 174 |
+
Lists all 30 tasks (15 Round 2 scenarios + 15 Round 1 cases).
|
| 175 |
+
Includes complete action schema for all 15 action types.
|
| 176 |
"""
|
| 177 |
all_tasks = task_manager.list_all_tasks()
|
| 178 |
return TaskListResponse(
|
|
|
|
| 190 |
async def grader(request: GraderRequest):
|
| 191 |
"""
|
| 192 |
Grades a completed episode action.
|
| 193 |
+
For Round 2 submit_report: computes score from DB performance improvement.
|
| 194 |
Returns float score strictly between 0.0 and 1.0 exclusive.
|
|
|
|
| 195 |
"""
|
| 196 |
try:
|
| 197 |
if request.action is None:
|
|
|
|
| 200 |
feedback = "No action provided for grading.",
|
| 201 |
breakdown = {"error": "null_action"}
|
| 202 |
)
|
| 203 |
+
|
| 204 |
+
# Round 2: submit_report grading uses DB state
|
| 205 |
+
if request.action.action_type == ActionType.SUBMIT_REPORT:
|
| 206 |
+
ep_state = environment.state()
|
| 207 |
+
perf_history = ep_state.action_counts.get("_perf_history", [0.0])
|
| 208 |
+
baseline = ep_state.action_counts.get("_baseline_score", 0.0)
|
| 209 |
+
best_score = ep_state.action_counts.get("_best_score", 0.0)
|
| 210 |
+
current = perf_history[-1] if perf_history else 0.0
|
| 211 |
+
max_possible = max(1.0, 100.0 - baseline)
|
| 212 |
+
|
| 213 |
+
perf_improvement = (current - baseline) / max_possible
|
| 214 |
+
step_efficiency = 1.0 - (ep_state.step_count / max(1, 50))
|
| 215 |
+
score = round(
|
| 216 |
+
(perf_improvement * 0.60) + (step_efficiency * 0.20) + 0.10, 4
|
| 217 |
+
)
|
| 218 |
+
score = max(0.001, min(0.999, score))
|
| 219 |
+
|
| 220 |
+
return GraderResponse(
|
| 221 |
+
score = score,
|
| 222 |
+
feedback = (
|
| 223 |
+
f"DB performance: {baseline:.1f} → {current:.1f} "
|
| 224 |
+
f"(best: {best_score:.1f}). "
|
| 225 |
+
f"Steps used: {ep_state.step_count}/50."
|
| 226 |
+
),
|
| 227 |
+
breakdown = {
|
| 228 |
+
"perf_improvement": round(perf_improvement, 4),
|
| 229 |
+
"step_efficiency": round(step_efficiency, 4),
|
| 230 |
+
"base_score": 0.10,
|
| 231 |
+
}
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
# Round 1 grading
|
| 235 |
score, breakdown, feedback = grade(request.action, request.task_id)
|
|
|
|
| 236 |
score = max(0.001, min(0.999, score))
|
| 237 |
+
return GraderResponse(score=score, feedback=feedback, breakdown=breakdown)
|
| 238 |
+
|
|
|
|
|
|
|
|
|
|
| 239 |
except Exception as e:
|
| 240 |
return GraderResponse(
|
| 241 |
score = 0.001,
|
|
|
|
| 243 |
breakdown = {"error": str(e)}
|
| 244 |
)
|
| 245 |
|
| 246 |
+
|
| 247 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 248 |
# 7. /baseline — POST
|
| 249 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 251 |
@app.post("/baseline", response_model=BaselineResponse, tags=["Baseline"])
|
| 252 |
async def baseline():
|
| 253 |
"""
|
| 254 |
+
Runs the baseline agent against all difficulty levels.
|
| 255 |
+
Must complete within 60 seconds.
|
|
|
|
| 256 |
"""
|
| 257 |
try:
|
| 258 |
import baseline as baseline_module
|
|
|
|
| 263 |
return results
|
| 264 |
except asyncio.TimeoutError:
|
| 265 |
return BaselineResponse(
|
| 266 |
+
results=[BaselineResult(
|
| 267 |
+
task_id="timeout", difficulty=DifficultyLevel.EASY,
|
| 268 |
+
score=0.0, steps=0, feedback="Baseline timed out."
|
| 269 |
+
)],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
average_score=0.0
|
| 271 |
)
|
| 272 |
except Exception as e:
|
| 273 |
return BaselineResponse(
|
| 274 |
+
results=[BaselineResult(
|
| 275 |
+
task_id="error", difficulty=DifficultyLevel.EASY,
|
| 276 |
+
score=0.0, steps=0, feedback=f"Baseline error: {str(e)}"
|
| 277 |
+
)],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
average_score=0.0
|
| 279 |
)
|
| 280 |
|
| 281 |
|
| 282 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 283 |
+
# 8. /progress — GET (Round 2 NEW)
|
| 284 |
+
# ─────────────────────────────────────────────
|
| 285 |
+
|
| 286 |
+
@app.get("/progress", response_model=ProgressResponse, tags=["Training"])
|
| 287 |
+
async def progress():
|
| 288 |
+
"""
|
| 289 |
+
Returns DB performance history for training visualization.
|
| 290 |
+
Used by evaluate_agent.py to generate reward curves.
|
| 291 |
+
Shows improvement from baseline to current score.
|
| 292 |
+
"""
|
| 293 |
+
ep_state = environment.state()
|
| 294 |
+
ac = ep_state.action_counts
|
| 295 |
+
perf_history = ac.get("_perf_history", [])
|
| 296 |
+
milestones = ac.get("_milestones", [])
|
| 297 |
+
baseline = ac.get("_baseline_score", 0.0)
|
| 298 |
+
target = ac.get("_target_score", 85.0)
|
| 299 |
+
best = ac.get("_best_score", 0.0)
|
| 300 |
+
current = perf_history[-1] if perf_history else 0.0
|
| 301 |
+
|
| 302 |
+
return ProgressResponse(
|
| 303 |
+
scenario_id = ep_state.task_id,
|
| 304 |
+
performance_score = current,
|
| 305 |
+
baseline_score = baseline,
|
| 306 |
+
target_score = target,
|
| 307 |
+
improvement_history = perf_history,
|
| 308 |
+
milestones_earned = milestones,
|
| 309 |
+
best_score = best,
|
| 310 |
+
steps_used = ep_state.step_count,
|
| 311 |
+
budget_remaining = max(0, 50 - ep_state.step_count),
|
| 312 |
+
total_reward = ep_state.total_reward,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 317 |
+
# ROOT
|
| 318 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 319 |
|
| 320 |
@app.get("/", tags=["System"])
|
| 321 |
async def root():
|
| 322 |
return {
|
| 323 |
+
"name": "SQL Database Engineer Agent — OpenEnv Environment",
|
| 324 |
+
"version": "2.0.0",
|
| 325 |
+
"tagline": "Training LLMs to act like senior database engineers",
|
| 326 |
"docs": "/docs",
|
| 327 |
"health": "/health",
|
| 328 |
+
"endpoints": ["/reset", "/step", "/state", "/tasks", "/grader", "/baseline", "/progress", "/health"],
|
| 329 |
+
"hackathon": "META x PyTorch x SST OpenEnv Hackathon — Finals April 25-26 Bangalore",
|
| 330 |
+
"domain": "Long-Horizon Database Engineering",
|
| 331 |
+
"tasks_count": 30,
|
| 332 |
+
"max_steps": 50,
|
| 333 |
+
"themes": ["Long-Horizon Planning", "World Modeling", "Self-Improvement", "Wildcard"],
|
| 334 |
+
}
|
blog/mini_blog.md
ADDED
|
File without changes
|
dataset/easy_scenarios.json
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"id": "easy_s001",
|
| 4 |
+
"description": "User lookup query taking 2s on 10K users table. Missing index on email column.",
|
| 5 |
+
"tables": [
|
| 6 |
+
{"name": "users", "rows": 10000, "indexes": ["PRIMARY"], "size_mb": 8}
|
| 7 |
+
],
|
| 8 |
+
"slow_queries": [
|
| 9 |
+
{"id": "q1", "sql": "SELECT * FROM users WHERE email=?", "avg_ms": 2000, "main_table": "users", "rows_examined": 10000}
|
| 10 |
+
],
|
| 11 |
+
"missing_index_hints": [
|
| 12 |
+
{"table": "users", "columns": ["email"], "reason": "email is used in WHERE clause but has no index"}
|
| 13 |
+
],
|
| 14 |
+
"performance_score_baseline": 8.0,
|
| 15 |
+
"target_score": 80.0,
|
| 16 |
+
"max_steps": 15,
|
| 17 |
+
"optimal_actions": ["inspect_query:q1", "analyze_indexes:users", "create_index:users:email", "submit_report"],
|
| 18 |
+
"category": "indexing"
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"id": "easy_s002",
|
| 22 |
+
"description": "Order status query scanning 50K orders. Composite index on user_id + status needed.",
|
| 23 |
+
"tables": [
|
| 24 |
+
{"name": "orders", "rows": 50000, "indexes": ["PRIMARY"], "size_mb": 120}
|
| 25 |
+
],
|
| 26 |
+
"slow_queries": [
|
| 27 |
+
{"id": "q1", "sql": "SELECT * FROM orders WHERE user_id=? AND status=?", "avg_ms": 3500, "main_table": "orders", "rows_examined": 50000}
|
| 28 |
+
],
|
| 29 |
+
"missing_index_hints": [
|
| 30 |
+
{"table": "orders", "columns": ["user_id", "status"], "reason": "Composite WHERE clause needs composite index"}
|
| 31 |
+
],
|
| 32 |
+
"performance_score_baseline": 5.0,
|
| 33 |
+
"target_score": 85.0,
|
| 34 |
+
"max_steps": 15,
|
| 35 |
+
"optimal_actions": ["inspect_query:q1", "create_index:orders:user_id,status", "submit_report"],
|
| 36 |
+
"category": "indexing"
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"id": "easy_s003",
|
| 40 |
+
"description": "Product search query doing full table scan on 20K products. Index on name column fixes it.",
|
| 41 |
+
"tables": [
|
| 42 |
+
{"name": "products", "rows": 20000, "indexes": ["PRIMARY"], "size_mb": 35}
|
| 43 |
+
],
|
| 44 |
+
"slow_queries": [
|
| 45 |
+
{"id": "q1", "sql": "SELECT id, name, price FROM products WHERE name LIKE ?", "avg_ms": 1800, "main_table": "products", "rows_examined": 20000}
|
| 46 |
+
],
|
| 47 |
+
"missing_index_hints": [
|
| 48 |
+
{"table": "products", "columns": ["name"], "reason": "LIKE queries benefit from index on name"}
|
| 49 |
+
],
|
| 50 |
+
"performance_score_baseline": 10.0,
|
| 51 |
+
"target_score": 78.0,
|
| 52 |
+
"max_steps": 15,
|
| 53 |
+
"optimal_actions": ["inspect_query:q1", "create_index:products:name", "submit_report"],
|
| 54 |
+
"category": "indexing"
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"id": "easy_s004",
|
| 58 |
+
"description": "Session lookup hitting 15K sessions table without index. Single index solves it.",
|
| 59 |
+
"tables": [
|
| 60 |
+
{"name": "sessions", "rows": 15000, "indexes": ["PRIMARY"], "size_mb": 12}
|
| 61 |
+
],
|
| 62 |
+
"slow_queries": [
|
| 63 |
+
{"id": "q1", "sql": "SELECT * FROM sessions WHERE user_id=? AND expires_at > NOW()", "avg_ms": 1500, "main_table": "sessions", "rows_examined": 15000}
|
| 64 |
+
],
|
| 65 |
+
"missing_index_hints": [
|
| 66 |
+
{"table": "sessions", "columns": ["user_id", "expires_at"], "reason": "Composite index on user_id + expires_at needed"}
|
| 67 |
+
],
|
| 68 |
+
"performance_score_baseline": 12.0,
|
| 69 |
+
"target_score": 80.0,
|
| 70 |
+
"max_steps": 15,
|
| 71 |
+
"optimal_actions": ["inspect_query:q1", "create_index:sessions:user_id,expires_at", "submit_report"],
|
| 72 |
+
"category": "indexing"
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"id": "easy_s005",
|
| 76 |
+
"description": "Log table growing to 30K entries. Query filtering by level and created_at is slow.",
|
| 77 |
+
"tables": [
|
| 78 |
+
{"name": "logs", "rows": 30000, "indexes": ["PRIMARY"], "size_mb": 50}
|
| 79 |
+
],
|
| 80 |
+
"slow_queries": [
|
| 81 |
+
{"id": "q1", "sql": "SELECT * FROM logs WHERE level=? AND created_at > ?", "avg_ms": 2200, "main_table": "logs", "rows_examined": 30000}
|
| 82 |
+
],
|
| 83 |
+
"missing_index_hints": [
|
| 84 |
+
{"table": "logs", "columns": ["level", "created_at"], "reason": "Compound filter needs compound index"}
|
| 85 |
+
],
|
| 86 |
+
"performance_score_baseline": 7.8,
|
| 87 |
+
"target_score": 80.0,
|
| 88 |
+
"max_steps": 15,
|
| 89 |
+
"optimal_actions": ["inspect_query:q1", "create_index:logs:level,created_at", "submit_report"],
|
| 90 |
+
"category": "indexing"
|
| 91 |
+
}
|
| 92 |
+
]
|
dataset/hard_scenarios.json
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"id": "hard_s001",
|
| 4 |
+
"description": "Financial DB: 500K transactions across 4 tables. 3 slow queries. Needs indexes, partition, and statistics.",
|
| 5 |
+
"tables": [
|
| 6 |
+
{"name": "transactions", "rows": 500000, "indexes": ["PRIMARY"], "size_mb": 2400},
|
| 7 |
+
{"name": "accounts", "rows": 50000, "indexes": ["PRIMARY"], "size_mb": 80},
|
| 8 |
+
{"name": "customers", "rows": 80000, "indexes": ["PRIMARY"], "size_mb": 120},
|
| 9 |
+
{"name": "audit_log", "rows": 1000000,"indexes": ["PRIMARY"], "size_mb": 5000}
|
| 10 |
+
],
|
| 11 |
+
"slow_queries": [
|
| 12 |
+
{"id": "q1", "sql": "SELECT * FROM transactions WHERE account_id=? AND status=? AND created_at > ?", "avg_ms": 15000, "main_table": "transactions", "rows_examined": 500000},
|
| 13 |
+
{"id": "q2", "sql": "SELECT c.*, COUNT(t.id) FROM customers c, transactions t WHERE c.id = t.customer_id AND t.amount > ? GROUP BY c.id", "avg_ms": 22000, "main_table": "transactions", "rows_examined": 500000},
|
| 14 |
+
{"id": "q3", "sql": "SELECT * FROM audit_log WHERE entity_id=? AND entity_type=? ORDER BY created_at DESC LIMIT 100", "avg_ms": 18000, "main_table": "audit_log", "rows_examined": 1000000}
|
| 15 |
+
],
|
| 16 |
+
"missing_index_hints": [
|
| 17 |
+
{"table": "transactions", "columns": ["account_id", "status", "created_at"], "reason": "Composite filter — high cardinality"},
|
| 18 |
+
{"table": "transactions", "columns": ["customer_id", "amount"], "reason": "JOIN + range filter"},
|
| 19 |
+
{"table": "audit_log", "columns": ["entity_id", "entity_type", "created_at"], "reason": "Lookup + ORDER BY on huge table"}
|
| 20 |
+
],
|
| 21 |
+
"performance_score_baseline": 4.2,
|
| 22 |
+
"target_score": 70.0,
|
| 23 |
+
"max_steps": 50,
|
| 24 |
+
"optimal_actions": [
|
| 25 |
+
"inspect_query:q1", "inspect_query:q2", "inspect_query:q3",
|
| 26 |
+
"analyze_indexes:transactions", "analyze_indexes:audit_log",
|
| 27 |
+
"create_index:transactions:account_id,status,created_at",
|
| 28 |
+
"create_index:transactions:customer_id,amount",
|
| 29 |
+
"create_index:audit_log:entity_id,entity_type,created_at",
|
| 30 |
+
"rewrite_query:q2:SELECT c.id, c.name, COUNT(t.id) as tx_count FROM customers c INNER JOIN transactions t ON c.id = t.customer_id WHERE t.amount > ? GROUP BY c.id, c.name",
|
| 31 |
+
"partition_table:audit_log",
|
| 32 |
+
"analyze_statistics:transactions",
|
| 33 |
+
"analyze_statistics:audit_log",
|
| 34 |
+
"submit_report"
|
| 35 |
+
],
|
| 36 |
+
"category": "financial"
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"id": "hard_s002",
|
| 40 |
+
"description": "SaaS platform: 8-table schema, 200K+ records. Dashboard queries taking 20s+. Full optimization campaign.",
|
| 41 |
+
"tables": [
|
| 42 |
+
{"name": "workspaces", "rows": 5000, "indexes": ["PRIMARY"], "size_mb": 10},
|
| 43 |
+
{"name": "users", "rows": 80000, "indexes": ["PRIMARY"], "size_mb": 120},
|
| 44 |
+
{"name": "projects", "rows": 200000, "indexes": ["PRIMARY"], "size_mb": 450},
|
| 45 |
+
{"name": "tasks", "rows": 800000, "indexes": ["PRIMARY"], "size_mb": 3000},
|
| 46 |
+
{"name": "comments", "rows": 500000, "indexes": ["PRIMARY"], "size_mb": 1800},
|
| 47 |
+
{"name": "attachments", "rows": 300000, "indexes": ["PRIMARY"], "size_mb": 900},
|
| 48 |
+
{"name": "activity_log", "rows": 2000000,"indexes": ["PRIMARY"], "size_mb": 8000},
|
| 49 |
+
{"name": "notifications", "rows": 400000, "indexes": ["PRIMARY"], "size_mb": 600}
|
| 50 |
+
],
|
| 51 |
+
"slow_queries": [
|
| 52 |
+
{"id": "q1", "sql": "SELECT * FROM tasks WHERE project_id=? AND assignee_id=? AND status != 'done' ORDER BY due_date ASC", "avg_ms": 20000, "main_table": "tasks", "rows_examined": 800000},
|
| 53 |
+
{"id": "q2", "sql": "SELECT * FROM activity_log WHERE workspace_id=? AND created_at > ? ORDER BY created_at DESC LIMIT 50", "avg_ms": 25000, "main_table": "activity_log", "rows_examined": 2000000},
|
| 54 |
+
{"id": "q3", "sql": "SELECT * FROM notifications WHERE user_id=? AND read=0", "avg_ms": 8000, "main_table": "notifications", "rows_examined": 400000}
|
| 55 |
+
],
|
| 56 |
+
"missing_index_hints": [
|
| 57 |
+
{"table": "tasks", "columns": ["project_id", "assignee_id", "status", "due_date"], "reason": "4-column filter + ORDER BY"},
|
| 58 |
+
{"table": "activity_log", "columns": ["workspace_id", "created_at"], "reason": "Range query on 2M row table — also partition candidate"},
|
| 59 |
+
{"table": "notifications","columns": ["user_id", "read"], "reason": "Hot path — unread notifications per user"}
|
| 60 |
+
],
|
| 61 |
+
"performance_score_baseline": 3.8,
|
| 62 |
+
"target_score": 68.0,
|
| 63 |
+
"max_steps": 50,
|
| 64 |
+
"optimal_actions": [
|
| 65 |
+
"inspect_query:q1", "inspect_query:q2", "inspect_query:q3",
|
| 66 |
+
"analyze_indexes:tasks", "analyze_indexes:activity_log",
|
| 67 |
+
"create_index:tasks:project_id,assignee_id,status,due_date",
|
| 68 |
+
"create_index:activity_log:workspace_id,created_at",
|
| 69 |
+
"create_index:notifications:user_id,read",
|
| 70 |
+
"partition_table:activity_log",
|
| 71 |
+
"analyze_statistics:tasks",
|
| 72 |
+
"analyze_statistics:activity_log",
|
| 73 |
+
"submit_report"
|
| 74 |
+
],
|
| 75 |
+
"category": "saas_platform"
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"id": "hard_s003",
|
| 79 |
+
"description": "Healthcare DB: 1M patient records. Compliance queries + clinical search + audit trail all slow.",
|
| 80 |
+
"tables": [
|
| 81 |
+
{"name": "patients", "rows": 1000000, "indexes": ["PRIMARY"], "size_mb": 4000},
|
| 82 |
+
{"name": "appointments", "rows": 500000, "indexes": ["PRIMARY"], "size_mb": 1500},
|
| 83 |
+
{"name": "prescriptions", "rows": 800000, "indexes": ["PRIMARY"], "size_mb": 2500},
|
| 84 |
+
{"name": "clinical_notes", "rows": 1200000, "indexes": ["PRIMARY"], "size_mb": 6000}
|
| 85 |
+
],
|
| 86 |
+
"slow_queries": [
|
| 87 |
+
{"id": "q1", "sql": "SELECT * FROM appointments WHERE patient_id=? AND doctor_id=? AND appointment_date BETWEEN ? AND ?", "avg_ms": 18000, "main_table": "appointments", "rows_examined": 500000},
|
| 88 |
+
{"id": "q2", "sql": "SELECT * FROM prescriptions WHERE patient_id=? AND medication_code=? AND prescribed_at > ?", "avg_ms": 14000, "main_table": "prescriptions", "rows_examined": 800000},
|
| 89 |
+
{"id": "q3", "sql": "SELECT * FROM clinical_notes WHERE patient_id=? ORDER BY created_at DESC LIMIT 20", "avg_ms": 22000, "main_table": "clinical_notes", "rows_examined": 1200000}
|
| 90 |
+
],
|
| 91 |
+
"missing_index_hints": [
|
| 92 |
+
{"table": "appointments", "columns": ["patient_id", "doctor_id", "appointment_date"], "reason": "Date range query + 2 foreign keys"},
|
| 93 |
+
{"table": "prescriptions", "columns": ["patient_id", "medication_code", "prescribed_at"], "reason": "Patient medication history"},
|
| 94 |
+
{"table": "clinical_notes","columns": ["patient_id", "created_at"], "reason": "Sorted history per patient on 1.2M rows"}
|
| 95 |
+
],
|
| 96 |
+
"performance_score_baseline": 3.5,
|
| 97 |
+
"target_score": 68.0,
|
| 98 |
+
"max_steps": 50,
|
| 99 |
+
"optimal_actions": [
|
| 100 |
+
"inspect_query:q1", "inspect_query:q2", "inspect_query:q3",
|
| 101 |
+
"analyze_indexes:appointments", "analyze_indexes:clinical_notes",
|
| 102 |
+
"create_index:appointments:patient_id,doctor_id,appointment_date",
|
| 103 |
+
"create_index:prescriptions:patient_id,medication_code,prescribed_at",
|
| 104 |
+
"create_index:clinical_notes:patient_id,created_at",
|
| 105 |
+
"partition_table:clinical_notes",
|
| 106 |
+
"analyze_statistics:appointments",
|
| 107 |
+
"analyze_statistics:clinical_notes",
|
| 108 |
+
"submit_report"
|
| 109 |
+
],
|
| 110 |
+
"category": "healthcare"
|
| 111 |
+
},
|
| 112 |
+
{
|
| 113 |
+
"id": "hard_s004",
|
| 114 |
+
"description": "Gaming leaderboard: 2M player records. Real-time ranking + history + match queries all degraded.",
|
| 115 |
+
"tables": [
|
| 116 |
+
{"name": "players", "rows": 2000000, "indexes": ["PRIMARY"], "size_mb": 5000},
|
| 117 |
+
{"name": "matches", "rows": 5000000, "indexes": ["PRIMARY"], "size_mb": 15000},
|
| 118 |
+
{"name": "leaderboards", "rows": 2000000, "indexes": ["PRIMARY"], "size_mb": 4000},
|
| 119 |
+
{"name": "achievements", "rows": 800000, "indexes": ["PRIMARY"], "size_mb": 2000}
|
| 120 |
+
],
|
| 121 |
+
"slow_queries": [
|
| 122 |
+
{"id": "q1", "sql": "SELECT * FROM leaderboards WHERE game_mode=? AND season=? ORDER BY score DESC LIMIT 100", "avg_ms": 30000, "main_table": "leaderboards", "rows_examined": 2000000},
|
| 123 |
+
{"id": "q2", "sql": "SELECT * FROM matches WHERE player_id=? AND game_mode=? AND played_at > ? ORDER BY played_at DESC", "avg_ms": 25000, "main_table": "matches", "rows_examined": 5000000},
|
| 124 |
+
{"id": "q3", "sql": "SELECT * FROM achievements WHERE player_id=? AND unlocked=1", "avg_ms": 12000, "main_table": "achievements", "rows_examined": 800000}
|
| 125 |
+
],
|
| 126 |
+
"missing_index_hints": [
|
| 127 |
+
{"table": "leaderboards", "columns": ["game_mode", "season", "score"], "reason": "Sorted leaderboard by mode+season"},
|
| 128 |
+
{"table": "matches", "columns": ["player_id", "game_mode", "played_at"], "reason": "Player history — 5M rows"},
|
| 129 |
+
{"table": "achievements", "columns": ["player_id", "unlocked"], "reason": "Unlocked achievements per player"}
|
| 130 |
+
],
|
| 131 |
+
"performance_score_baseline": 2.8,
|
| 132 |
+
"target_score": 65.0,
|
| 133 |
+
"max_steps": 50,
|
| 134 |
+
"optimal_actions": [
|
| 135 |
+
"inspect_query:q1", "inspect_query:q2", "inspect_query:q3",
|
| 136 |
+
"analyze_indexes:leaderboards", "analyze_indexes:matches",
|
| 137 |
+
"create_index:leaderboards:game_mode,season,score",
|
| 138 |
+
"create_index:matches:player_id,game_mode,played_at",
|
| 139 |
+
"create_index:achievements:player_id,unlocked",
|
| 140 |
+
"partition_table:matches",
|
| 141 |
+
"analyze_statistics:leaderboards",
|
| 142 |
+
"analyze_statistics:matches",
|
| 143 |
+
"submit_report"
|
| 144 |
+
],
|
| 145 |
+
"category": "gaming"
|
| 146 |
+
},
|
| 147 |
+
{
|
| 148 |
+
"id": "hard_s005",
|
| 149 |
+
"description": "Logistics platform: 6 tables, 3M shipment records. ETA queries, route optimization, and reporting all slow.",
|
| 150 |
+
"tables": [
|
| 151 |
+
{"name": "shipments", "rows": 3000000, "indexes": ["PRIMARY"], "size_mb": 9000},
|
| 152 |
+
{"name": "routes", "rows": 500000, "indexes": ["PRIMARY"], "size_mb": 1500},
|
| 153 |
+
{"name": "drivers", "rows": 100000, "indexes": ["PRIMARY"], "size_mb": 200},
|
| 154 |
+
{"name": "vehicles", "rows": 80000, "indexes": ["PRIMARY"], "size_mb": 150},
|
| 155 |
+
{"name": "warehouses", "rows": 20000, "indexes": ["PRIMARY"], "size_mb": 40},
|
| 156 |
+
{"name": "tracking", "rows": 10000000,"indexes": ["PRIMARY"], "size_mb": 30000}
|
| 157 |
+
],
|
| 158 |
+
"slow_queries": [
|
| 159 |
+
{"id": "q1", "sql": "SELECT * FROM shipments WHERE origin_warehouse=? AND status=? AND scheduled_at BETWEEN ? AND ?", "avg_ms": 28000, "main_table": "shipments", "rows_examined": 3000000},
|
| 160 |
+
{"id": "q2", "sql": "SELECT * FROM tracking WHERE shipment_id=? ORDER BY recorded_at DESC LIMIT 50", "avg_ms": 35000, "main_table": "tracking", "rows_examined": 10000000},
|
| 161 |
+
{"id": "q3", "sql": "SELECT d.*, COUNT(s.id) FROM drivers d, shipments s WHERE d.id = s.driver_id AND s.status='in_transit' GROUP BY d.id", "avg_ms": 20000, "main_table": "shipments", "rows_examined": 3000000}
|
| 162 |
+
],
|
| 163 |
+
"missing_index_hints": [
|
| 164 |
+
{"table": "shipments", "columns": ["origin_warehouse", "status", "scheduled_at"], "reason": "3-column filter on 3M rows"},
|
| 165 |
+
{"table": "tracking", "columns": ["shipment_id", "recorded_at"], "reason": "Lookup + sort on 10M row table — partition candidate"},
|
| 166 |
+
{"table": "shipments", "columns": ["driver_id", "status"], "reason": "JOIN + WHERE filter for driver stats"}
|
| 167 |
+
],
|
| 168 |
+
"performance_score_baseline": 2.5,
|
| 169 |
+
"target_score": 65.0,
|
| 170 |
+
"max_steps": 50,
|
| 171 |
+
"optimal_actions": [
|
| 172 |
+
"inspect_query:q1", "inspect_query:q2", "inspect_query:q3",
|
| 173 |
+
"analyze_indexes:shipments", "analyze_indexes:tracking",
|
| 174 |
+
"create_index:shipments:origin_warehouse,status,scheduled_at",
|
| 175 |
+
"create_index:tracking:shipment_id,recorded_at",
|
| 176 |
+
"create_index:shipments:driver_id,status",
|
| 177 |
+
"rewrite_query:q3:SELECT d.id, d.name, COUNT(s.id) as active_shipments FROM drivers d INNER JOIN shipments s ON d.id = s.driver_id WHERE s.status='in_transit' GROUP BY d.id, d.name",
|
| 178 |
+
"partition_table:tracking",
|
| 179 |
+
"analyze_statistics:shipments",
|
| 180 |
+
"analyze_statistics:tracking",
|
| 181 |
+
"submit_report"
|
| 182 |
+
],
|
| 183 |
+
"category": "logistics"
|
| 184 |
+
}
|
| 185 |
+
]
|
dataset/medium_scenarios.json
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"id": "medium_s001",
|
| 4 |
+
"description": "E-commerce DB: 50K orders + 8K users. Two slow queries. Composite indexes + statistics update needed.",
|
| 5 |
+
"tables": [
|
| 6 |
+
{"name": "orders", "rows": 50000, "indexes": ["PRIMARY"], "size_mb": 280},
|
| 7 |
+
{"name": "users", "rows": 8000, "indexes": ["PRIMARY", "email_idx"], "size_mb": 15}
|
| 8 |
+
],
|
| 9 |
+
"slow_queries": [
|
| 10 |
+
{"id": "q1", "sql": "SELECT * FROM orders WHERE user_id=? AND status=?", "avg_ms": 8500, "main_table": "orders", "rows_examined": 50000},
|
| 11 |
+
{"id": "q2", "sql": "SELECT COUNT(*) FROM orders o JOIN users u ON o.user_id=u.id WHERE u.country=?", "avg_ms": 3200, "main_table": "orders", "rows_examined": 50000}
|
| 12 |
+
],
|
| 13 |
+
"missing_index_hints": [
|
| 14 |
+
{"table": "orders", "columns": ["user_id", "status"], "reason": "Composite WHERE filter"},
|
| 15 |
+
{"table": "users", "columns": ["country"], "reason": "JOIN + WHERE filter on country"}
|
| 16 |
+
],
|
| 17 |
+
"performance_score_baseline": 12.5,
|
| 18 |
+
"target_score": 75.0,
|
| 19 |
+
"max_steps": 25,
|
| 20 |
+
"optimal_actions": [
|
| 21 |
+
"inspect_query:q1", "inspect_query:q2",
|
| 22 |
+
"analyze_indexes:orders", "analyze_indexes:users",
|
| 23 |
+
"create_index:orders:user_id,status",
|
| 24 |
+
"create_index:users:country",
|
| 25 |
+
"analyze_statistics:orders",
|
| 26 |
+
"submit_report"
|
| 27 |
+
],
|
| 28 |
+
"category": "multi_table"
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"id": "medium_s002",
|
| 32 |
+
"description": "Blog platform: 100K posts + 20K authors. Search and author lookup queries both slow.",
|
| 33 |
+
"tables": [
|
| 34 |
+
{"name": "posts", "rows": 100000, "indexes": ["PRIMARY"], "size_mb": 450},
|
| 35 |
+
{"name": "authors", "rows": 20000, "indexes": ["PRIMARY"], "size_mb": 40}
|
| 36 |
+
],
|
| 37 |
+
"slow_queries": [
|
| 38 |
+
{"id": "q1", "sql": "SELECT * FROM posts WHERE author_id=? AND published=1 ORDER BY created_at DESC", "avg_ms": 6000, "main_table": "posts", "rows_examined": 100000},
|
| 39 |
+
{"id": "q2", "sql": "SELECT * FROM authors WHERE username=?", "avg_ms": 2100, "main_table": "authors", "rows_examined": 20000}
|
| 40 |
+
],
|
| 41 |
+
"missing_index_hints": [
|
| 42 |
+
{"table": "posts", "columns": ["author_id", "published", "created_at"], "reason": "Multi-column filter + ORDER BY"},
|
| 43 |
+
{"table": "authors", "columns": ["username"], "reason": "Unique lookup by username"}
|
| 44 |
+
],
|
| 45 |
+
"performance_score_baseline": 9.0,
|
| 46 |
+
"target_score": 78.0,
|
| 47 |
+
"max_steps": 25,
|
| 48 |
+
"optimal_actions": [
|
| 49 |
+
"inspect_query:q1", "inspect_query:q2",
|
| 50 |
+
"create_index:posts:author_id,published,created_at",
|
| 51 |
+
"create_index:authors:username",
|
| 52 |
+
"submit_report"
|
| 53 |
+
],
|
| 54 |
+
"category": "multi_table"
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"id": "medium_s003",
|
| 58 |
+
"description": "Inventory system: 80K products + 200K stock movements. Two queries needing index + rewrite.",
|
| 59 |
+
"tables": [
|
| 60 |
+
{"name": "products", "rows": 80000, "indexes": ["PRIMARY"], "size_mb": 200},
|
| 61 |
+
{"name": "stock_movements", "rows": 200000, "indexes": ["PRIMARY"], "size_mb": 600}
|
| 62 |
+
],
|
| 63 |
+
"slow_queries": [
|
| 64 |
+
{"id": "q1", "sql": "SELECT * FROM stock_movements WHERE product_id=? AND movement_type=? AND created_at > ?", "avg_ms": 9000, "main_table": "stock_movements", "rows_examined": 200000},
|
| 65 |
+
{"id": "q2", "sql": "SELECT p.*, SUM(sm.quantity) FROM products p, stock_movements sm WHERE p.id = sm.product_id GROUP BY p.id", "avg_ms": 12000, "main_table": "products", "rows_examined": 200000}
|
| 66 |
+
],
|
| 67 |
+
"missing_index_hints": [
|
| 68 |
+
{"table": "stock_movements", "columns": ["product_id", "movement_type", "created_at"], "reason": "Composite filter on 3 columns"},
|
| 69 |
+
{"table": "products", "columns": ["id"], "reason": "JOIN column — rewrite implicit JOIN to INNER JOIN"}
|
| 70 |
+
],
|
| 71 |
+
"performance_score_baseline": 6.5,
|
| 72 |
+
"target_score": 72.0,
|
| 73 |
+
"max_steps": 30,
|
| 74 |
+
"optimal_actions": [
|
| 75 |
+
"inspect_query:q1", "inspect_query:q2",
|
| 76 |
+
"create_index:stock_movements:product_id,movement_type,created_at",
|
| 77 |
+
"rewrite_query:q2:SELECT p.id, p.name, SUM(sm.quantity) FROM products p INNER JOIN stock_movements sm ON p.id = sm.product_id GROUP BY p.id",
|
| 78 |
+
"analyze_statistics:stock_movements",
|
| 79 |
+
"submit_report"
|
| 80 |
+
],
|
| 81 |
+
"category": "rewrite_and_index"
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"id": "medium_s004",
|
| 85 |
+
"description": "Ticketing system: 60K tickets + 5K agents. Status queue and agent workload queries are slow.",
|
| 86 |
+
"tables": [
|
| 87 |
+
{"name": "tickets", "rows": 60000, "indexes": ["PRIMARY"], "size_mb": 180},
|
| 88 |
+
{"name": "agents", "rows": 5000, "indexes": ["PRIMARY"], "size_mb": 8}
|
| 89 |
+
],
|
| 90 |
+
"slow_queries": [
|
| 91 |
+
{"id": "q1", "sql": "SELECT * FROM tickets WHERE status=? AND priority=? ORDER BY created_at ASC", "avg_ms": 5500, "main_table": "tickets", "rows_examined": 60000},
|
| 92 |
+
{"id": "q2", "sql": "SELECT agent_id, COUNT(*) as open_count FROM tickets WHERE status='open' GROUP BY agent_id", "avg_ms": 4200, "main_table": "tickets", "rows_examined": 60000}
|
| 93 |
+
],
|
| 94 |
+
"missing_index_hints": [
|
| 95 |
+
{"table": "tickets", "columns": ["status", "priority", "created_at"], "reason": "Three-column filter with ORDER BY"},
|
| 96 |
+
{"table": "tickets", "columns": ["status", "agent_id"], "reason": "GROUP BY + WHERE filter"}
|
| 97 |
+
],
|
| 98 |
+
"performance_score_baseline": 11.0,
|
| 99 |
+
"target_score": 76.0,
|
| 100 |
+
"max_steps": 25,
|
| 101 |
+
"optimal_actions": [
|
| 102 |
+
"inspect_query:q1", "inspect_query:q2",
|
| 103 |
+
"analyze_indexes:tickets",
|
| 104 |
+
"create_index:tickets:status,priority,created_at",
|
| 105 |
+
"create_index:tickets:status,agent_id",
|
| 106 |
+
"submit_report"
|
| 107 |
+
],
|
| 108 |
+
"category": "multi_index"
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"id": "medium_s005",
|
| 112 |
+
"description": "Analytics DB: 150K events + 10K users. Event funnel query and user lookup both need optimization.",
|
| 113 |
+
"tables": [
|
| 114 |
+
{"name": "events", "rows": 150000, "indexes": ["PRIMARY"], "size_mb": 700},
|
| 115 |
+
{"name": "users", "rows": 10000, "indexes": ["PRIMARY"], "size_mb": 20}
|
| 116 |
+
],
|
| 117 |
+
"slow_queries": [
|
| 118 |
+
{"id": "q1", "sql": "SELECT * FROM events WHERE user_id=? AND event_type=? AND occurred_at BETWEEN ? AND ?", "avg_ms": 11000, "main_table": "events", "rows_examined": 150000},
|
| 119 |
+
{"id": "q2", "sql": "SELECT * FROM users WHERE signup_source=? AND created_at > ?", "avg_ms": 3000, "main_table": "users", "rows_examined": 10000}
|
| 120 |
+
],
|
| 121 |
+
"missing_index_hints": [
|
| 122 |
+
{"table": "events", "columns": ["user_id", "event_type", "occurred_at"], "reason": "Range query on 3 columns"},
|
| 123 |
+
{"table": "users", "columns": ["signup_source", "created_at"], "reason": "Composite filter on signup data"}
|
| 124 |
+
],
|
| 125 |
+
"performance_score_baseline": 5.5,
|
| 126 |
+
"target_score": 74.0,
|
| 127 |
+
"max_steps": 30,
|
| 128 |
+
"optimal_actions": [
|
| 129 |
+
"inspect_query:q1", "inspect_query:q2",
|
| 130 |
+
"create_index:events:user_id,event_type,occurred_at",
|
| 131 |
+
"create_index:users:signup_source,created_at",
|
| 132 |
+
"analyze_statistics:events",
|
| 133 |
+
"submit_report"
|
| 134 |
+
],
|
| 135 |
+
"category": "analytics"
|
| 136 |
+
}
|
| 137 |
+
]
|
env/__pycache__/models.cpython-312.pyc
CHANGED
|
Binary files a/env/__pycache__/models.cpython-312.pyc and b/env/__pycache__/models.cpython-312.pyc differ
|
|
|
env/environment.py
CHANGED
|
@@ -1,3 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import time
|
| 2 |
import random
|
| 3 |
from typing import Optional
|
|
@@ -9,29 +15,28 @@ from env.models import (
|
|
| 9 |
)
|
| 10 |
from env.tasks import task_manager
|
| 11 |
from env.reward import compute_reward, is_done, MAX_STEPS
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class SQLDebuggerEnvironment:
|
| 15 |
"""
|
| 16 |
-
OpenEnv-compliant SQL
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
- Dense reward signal at every step
|
| 25 |
-
- No state leakage between episodes
|
| 26 |
-
- Graceful handling of all edge cases
|
| 27 |
-
- Deterministic grading
|
| 28 |
-
- Thread-safe episode state
|
| 29 |
"""
|
| 30 |
|
| 31 |
def __init__(self):
|
| 32 |
-
self._state
|
| 33 |
-
self._current_task
|
| 34 |
-
self._started_at
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
# reset() → Observation
|
|
@@ -39,14 +44,8 @@ class SQLDebuggerEnvironment:
|
|
| 39 |
|
| 40 |
def reset(self, difficulty: Optional[str] = None, task_id: Optional[str] = None) -> Observation:
|
| 41 |
"""
|
| 42 |
-
Starts a fresh episode. Clears ALL state
|
| 43 |
-
Loads
|
| 44 |
-
Returns the initial Observation the agent sees.
|
| 45 |
-
|
| 46 |
-
Edge cases handled:
|
| 47 |
-
- reset() called mid-episode β cleanly resets, no state leakage
|
| 48 |
-
- invalid difficulty β defaults to random
|
| 49 |
-
- dataset empty β raises ValueError with clear message
|
| 50 |
"""
|
| 51 |
|
| 52 |
# ββ Resolve difficulty ββββββββββββββββββββββββββββββββββββ
|
|
@@ -54,7 +53,6 @@ class SQLDebuggerEnvironment:
|
|
| 54 |
try:
|
| 55 |
diff_enum = DifficultyLevel(difficulty.lower())
|
| 56 |
except ValueError:
|
| 57 |
-
# Invalid difficulty β pick random
|
| 58 |
diff_enum = random.choice(list(DifficultyLevel))
|
| 59 |
else:
|
| 60 |
diff_enum = random.choice(list(DifficultyLevel))
|
|
@@ -65,7 +63,18 @@ class SQLDebuggerEnvironment:
|
|
| 65 |
except Exception as e:
|
| 66 |
raise ValueError(f"Failed to load task: {str(e)}")
|
| 67 |
|
| 68 |
-
# ββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
self._current_task = task
|
| 70 |
self._started_at = time.time()
|
| 71 |
self._state = EpisodeState(
|
|
@@ -76,147 +85,194 @@ class SQLDebuggerEnvironment:
|
|
| 76 |
done = False,
|
| 77 |
hints_used = 0,
|
| 78 |
previous_actions = [],
|
| 79 |
-
action_counts = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
started_at = self._started_at,
|
| 81 |
last_reward = 0.0,
|
| 82 |
initialized = True,
|
| 83 |
)
|
| 84 |
|
| 85 |
-
|
| 86 |
-
context = task_manager.build_observation_context(task)
|
| 87 |
-
return Observation(
|
| 88 |
-
task_id = task["id"],
|
| 89 |
-
task_description = task["description"],
|
| 90 |
-
current_context = context,
|
| 91 |
-
step_count = 0,
|
| 92 |
-
difficulty = diff_enum,
|
| 93 |
-
max_steps = MAX_STEPS,
|
| 94 |
-
hints_used = 0,
|
| 95 |
-
previous_actions = [],
|
| 96 |
-
metadata = {
|
| 97 |
-
"category": task.get("category", ""),
|
| 98 |
-
"estimated_steps": task.get("estimated_fix_steps", 5),
|
| 99 |
-
"started_at": self._started_at,
|
| 100 |
-
}
|
| 101 |
-
)
|
| 102 |
|
| 103 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 104 |
-
# step() β
|
| 105 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 106 |
|
| 107 |
def step(self, action: Optional[Action]) -> StepResponse:
|
| 108 |
"""
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
Edge cases handled:
|
| 113 |
-
- step() called before reset() β auto-resets
|
| 114 |
-
- null action β reward=-0.1, done=False, never crash
|
| 115 |
-
- malformed action payload β catches ValidationError
|
| 116 |
-
- agent loops (same action 3+ times) β loop penalty
|
| 117 |
-
- episode already done β returns terminal observation
|
| 118 |
-
- max steps reached β forces done=True
|
| 119 |
-
- extremely long payload β truncated in models.py
|
| 120 |
"""
|
| 121 |
|
| 122 |
# ββ Auto-reset if not initialized ββββββββββββββββββββββββ
|
| 123 |
if not self._state.initialized or self._current_task is None:
|
| 124 |
obs = self.reset()
|
| 125 |
return StepResponse(
|
| 126 |
-
observation=obs,
|
| 127 |
-
reward=Reward(score=0.5, breakdown={"auto_reset": True}, feedback="Environment auto-reset."),
|
| 128 |
-
done=False,
|
| 129 |
-
info={"auto_reset": True}
|
| 130 |
)
|
| 131 |
|
| 132 |
# ββ Episode already done ββββββββββββββββββββββββββββββββββ
|
| 133 |
if self._state.done:
|
| 134 |
obs = self._build_observation()
|
| 135 |
return StepResponse(
|
| 136 |
-
observation=obs,
|
| 137 |
-
reward=Reward(score=0.5, breakdown={"episode_done": True}, feedback="Episode
|
| 138 |
-
done=True,
|
| 139 |
-
info={"episode_done": True, "total_reward": self._state.total_reward}
|
| 140 |
)
|
| 141 |
|
| 142 |
-
# ββ Handle null
|
| 143 |
if action is None or action.payload is None:
|
| 144 |
self._state.step_count += 1
|
| 145 |
-
obs
|
| 146 |
-
reward = Reward(
|
| 147 |
-
|
| 148 |
-
breakdown={"invalid_action": 0.001},
|
| 149 |
-
feedback="Null or invalid action received."
|
| 150 |
-
)
|
| 151 |
-
self._state.last_reward = -0.1
|
| 152 |
-
self._state.total_reward = round(self._state.total_reward - 0.1, 4)
|
| 153 |
-
done = self._state.step_count >= MAX_STEPS
|
| 154 |
self._state.done = done
|
| 155 |
return StepResponse(observation=obs, reward=reward, done=done, info={"error": "null_action"})
|
| 156 |
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
action_type_val = action.action_type.value if hasattr(action.action_type, "value") else str(action.action_type)
|
| 160 |
-
except Exception:
|
| 161 |
-
action_type_val = "unknown"
|
| 162 |
|
| 163 |
# ── Update step count ──────────────────────────────────────
|
| 164 |
self._state.step_count += 1
|
| 165 |
self._state.previous_actions.append(action_type_val)
|
| 166 |
-
self._state.action_counts[action_type_val] =
|
|
|
|
| 167 |
|
| 168 |
-
# ββ
|
| 169 |
-
if
|
| 170 |
self._state.hints_used += 1
|
| 171 |
-
# Inject hint into next observation context
|
| 172 |
hint_text = task_manager.get_hint(self._current_task, self._state.hints_used)
|
| 173 |
self._current_task["_last_hint"] = hint_text
|
| 174 |
|
| 175 |
-
# ββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
reward = compute_reward(
|
| 177 |
-
action
|
| 178 |
-
task_id
|
| 179 |
-
difficulty
|
| 180 |
-
step_count
|
| 181 |
-
previous_actions
|
| 182 |
-
hints_used
|
| 183 |
-
estimated_steps
|
| 184 |
-
action_counts
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
)
|
| 186 |
|
|
|
|
|
|
|
|
|
|
| 187 |
# ββ Update cumulative reward ββββββββββββββββββββββββββββββ
|
| 188 |
self._state.last_reward = reward.score
|
| 189 |
self._state.total_reward = round(self._state.total_reward + reward.score, 4)
|
| 190 |
|
| 191 |
-
# ββ Check done
|
|
|
|
|
|
|
|
|
|
| 192 |
done = is_done(
|
| 193 |
-
action_type
|
| 194 |
-
step_count
|
| 195 |
-
grader_score
|
|
|
|
| 196 |
)
|
| 197 |
self._state.done = done
|
| 198 |
|
| 199 |
-
# ββ Build
|
| 200 |
obs = self._build_observation()
|
| 201 |
|
| 202 |
-
# ββ
|
| 203 |
info = {
|
| 204 |
-
"step_count":
|
| 205 |
-
"total_reward":
|
| 206 |
-
"hints_used":
|
| 207 |
-
"
|
| 208 |
-
"
|
| 209 |
-
"
|
|
|
|
|
|
|
|
|
|
| 210 |
}
|
| 211 |
if done:
|
| 212 |
info["episode_summary"] = {
|
| 213 |
-
"total_steps":
|
| 214 |
-
"total_reward":
|
| 215 |
-
"hints_used":
|
| 216 |
-
"duration_sec":
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
}
|
| 218 |
|
| 219 |
-
# Normalize reward
|
| 220 |
normalized_score = max(0.001, min(0.999, (reward.score + 1.0) / 2.0))
|
| 221 |
reward = Reward(
|
| 222 |
score=normalized_score,
|
|
@@ -231,13 +287,6 @@ class SQLDebuggerEnvironment:
|
|
| 231 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 232 |
|
| 233 |
def state(self) -> EpisodeState:
|
| 234 |
-
"""
|
| 235 |
-
Returns the full current state at any point.
|
| 236 |
-
Must be JSON-serializable. Must always reflect latest step.
|
| 237 |
-
|
| 238 |
-
Edge case: state() called before reset() β returns default empty state.
|
| 239 |
-
Never crashes.
|
| 240 |
-
"""
|
| 241 |
return self._state
|
| 242 |
|
| 243 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -245,13 +294,9 @@ class SQLDebuggerEnvironment:
|
|
| 245 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 246 |
|
| 247 |
def _build_observation(self) -> Observation:
|
| 248 |
-
"""
|
| 249 |
-
|
| 250 |
-
Injects hint into context if one was just requested.
|
| 251 |
-
CRITICAL: Never leaks fixed_query (ground truth) to agent.
|
| 252 |
-
"""
|
| 253 |
if self._current_task is None:
|
| 254 |
-
# Fallback safe observation
|
| 255 |
return Observation(
|
| 256 |
task_id = "none",
|
| 257 |
task_description = "No task loaded. Call reset() first.",
|
|
@@ -264,14 +309,33 @@ class SQLDebuggerEnvironment:
|
|
| 264 |
metadata = {}
|
| 265 |
)
|
| 266 |
|
|
|
|
| 267 |
context = task_manager.build_observation_context(self._current_task)
|
| 268 |
|
| 269 |
-
# Inject
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
if "_last_hint" in self._current_task:
|
| 271 |
context["last_hint"] = self._current_task["_last_hint"]
|
| 272 |
|
| 273 |
-
|
| 274 |
-
context["steps_remaining"] = MAX_STEPS - self._state.step_count
|
| 275 |
context["total_reward_so_far"] = self._state.total_reward
|
| 276 |
|
| 277 |
return Observation(
|
|
@@ -284,10 +348,11 @@ class SQLDebuggerEnvironment:
|
|
| 284 |
hints_used = self._state.hints_used,
|
| 285 |
previous_actions = self._state.previous_actions.copy(),
|
| 286 |
metadata = {
|
| 287 |
-
"category":
|
| 288 |
-
"
|
| 289 |
-
"
|
| 290 |
-
"
|
|
|
|
| 291 |
}
|
| 292 |
)
|
| 293 |
|
|
@@ -296,4 +361,4 @@ class SQLDebuggerEnvironment:
|
|
| 296 |
# SINGLETON INSTANCE (used by FastAPI)
|
| 297 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 298 |
|
| 299 |
-
environment = SQLDebuggerEnvironment()
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
env/environment.py — SQL Database Engineer Agent (SDEA)
|
| 3 |
+
Round 2: Long-horizon DB optimization environment.
|
| 4 |
+
Agent manages a simulated production database over 50 steps.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
import time
|
| 8 |
import random
|
| 9 |
from typing import Optional
|
|
|
|
| 15 |
)
|
| 16 |
from env.tasks import task_manager
|
| 17 |
from env.reward import compute_reward, is_done, MAX_STEPS
|
| 18 |
+
from env.db_simulator import DatabaseSimulator
|
| 19 |
|
| 20 |
|
| 21 |
class SQLDebuggerEnvironment:
|
| 22 |
"""
|
| 23 |
+
OpenEnv-compliant SQL Database Engineer Agent Environment.
|
| 24 |
+
|
| 25 |
+
Round 2 evolution:
|
| 26 |
+
- 50-step long-horizon episodes (up from 20)
|
| 27 |
+
- 10 action types including DB-specific actions
|
| 28 |
+
- DatabaseSimulator tracks real performance score 0-100
|
| 29 |
+
- Milestone bonuses at 25%/50%/75% improvement
|
| 30 |
+
- Backward compatible with Round 1 actions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
"""
|
| 32 |
|
| 33 |
def __init__(self):
|
| 34 |
+
self._state = EpisodeState()
|
| 35 |
+
self._current_task = None
|
| 36 |
+
self._started_at = None
|
| 37 |
+
self._db_sim: Optional[DatabaseSimulator] = None
|
| 38 |
+
self._milestones_earned: set = set()
|
| 39 |
+
self._baseline_score: float = 0.0
|
| 40 |
|
| 41 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
# reset() → Observation
|
|
|
|
| 44 |
|
| 45 |
def reset(self, difficulty: Optional[str] = None, task_id: Optional[str] = None) -> Observation:
|
| 46 |
"""
|
| 47 |
+
Starts a fresh episode. Clears ALL state.
|
| 48 |
+
Loads scenario and initializes DatabaseSimulator.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
"""
|
| 50 |
|
| 51 |
# ββ Resolve difficulty ββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 53 |
try:
|
| 54 |
diff_enum = DifficultyLevel(difficulty.lower())
|
| 55 |
except ValueError:
|
|
|
|
| 56 |
diff_enum = random.choice(list(DifficultyLevel))
|
| 57 |
else:
|
| 58 |
diff_enum = random.choice(list(DifficultyLevel))
|
|
|
|
| 63 |
except Exception as e:
|
| 64 |
raise ValueError(f"Failed to load task: {str(e)}")
|
| 65 |
|
| 66 |
+
# ββ Initialize DatabaseSimulator ββββββββββββββββββββββββββ
|
| 67 |
+
# Only initialize for Round 2 scenarios (have 'tables' key)
|
| 68 |
+
if "tables" in task and "slow_queries" in task:
|
| 69 |
+
self._db_sim = DatabaseSimulator(task)
|
| 70 |
+
self._baseline_score = self._db_sim.get_performance_score()
|
| 71 |
+
else:
|
| 72 |
+
# Round 1 task — no DB simulator needed
|
| 73 |
+
self._db_sim = None
|
| 74 |
+
self._baseline_score = 0.0
|
| 75 |
+
self._milestones_earned = set()
|
| 76 |
+
|
| 77 |
+
# ββ Reset episode state βββββββββββββββββββββββββββββββββββ
|
| 78 |
self._current_task = task
|
| 79 |
self._started_at = time.time()
|
| 80 |
self._state = EpisodeState(
|
|
|
|
| 85 |
done = False,
|
| 86 |
hints_used = 0,
|
| 87 |
previous_actions = [],
|
| 88 |
+
action_counts = {
|
| 89 |
+
"_baseline_score": self._baseline_score,
|
| 90 |
+
"_target_score": task.get("target_score", 85.0),
|
| 91 |
+
"_milestones": [],
|
| 92 |
+
"_perf_history": [self._baseline_score],
|
| 93 |
+
"_best_score": self._baseline_score,
|
| 94 |
+
},
|
| 95 |
started_at = self._started_at,
|
| 96 |
last_reward = 0.0,
|
| 97 |
initialized = True,
|
| 98 |
)
|
| 99 |
|
| 100 |
+
return self._build_observation()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 103 |
+
# step() → StepResponse
|
| 104 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 105 |
|
| 106 |
def step(self, action: Optional[Action]) -> StepResponse:
|
| 107 |
"""
|
| 108 |
+
Processes an action, updates DB simulator, computes reward.
|
| 109 |
+
Handles all Round 2 DB engineering actions.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
"""
|
| 111 |
|
| 112 |
# ββ Auto-reset if not initialized ββββββββββββββββββββββββ
|
| 113 |
if not self._state.initialized or self._current_task is None:
|
| 114 |
obs = self.reset()
|
| 115 |
return StepResponse(
|
| 116 |
+
observation = obs,
|
| 117 |
+
reward = Reward(score=0.5, breakdown={"auto_reset": True}, feedback="Environment auto-reset."),
|
| 118 |
+
done = False,
|
| 119 |
+
info = {"auto_reset": True}
|
| 120 |
)
|
| 121 |
|
| 122 |
# ββ Episode already done ββββββββββββββββββββββββββββββββββ
|
| 123 |
if self._state.done:
|
| 124 |
obs = self._build_observation()
|
| 125 |
return StepResponse(
|
| 126 |
+
observation = obs,
|
| 127 |
+
reward = Reward(score=0.5, breakdown={"episode_done": True}, feedback="Episode finished. Call reset()."),
|
| 128 |
+
done = True,
|
| 129 |
+
info = {"episode_done": True, "total_reward": self._state.total_reward}
|
| 130 |
)
|
| 131 |
|
| 132 |
+
# ββ Handle null action ββββββββββββββββββββββββββββββββββββ
|
| 133 |
if action is None or action.payload is None:
|
| 134 |
self._state.step_count += 1
|
| 135 |
+
obs = self._build_observation()
|
| 136 |
+
reward = Reward(score=0.001, breakdown={"invalid_action": 0.001}, feedback="Null action.")
|
| 137 |
+
done = self._state.step_count >= MAX_STEPS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
self._state.done = done
|
| 139 |
return StepResponse(observation=obs, reward=reward, done=done, info={"error": "null_action"})
|
| 140 |
|
| 141 |
+
action_type_val = action.action_type.value if hasattr(action.action_type, "value") else str(action.action_type)
|
| 142 |
+
action_type_enum = action.action_type
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
# ── Update step count ──────────────────────────────────────
|
| 145 |
self._state.step_count += 1
|
| 146 |
self._state.previous_actions.append(action_type_val)
|
| 147 |
+
self._state.action_counts[action_type_val] = \
|
| 148 |
+
self._state.action_counts.get(action_type_val, 0) + 1
|
| 149 |
|
| 150 |
+
# ββ Handle hints ββββββββββββββββββββββββββββββββββββββββββ
|
| 151 |
+
if action_type_enum == ActionType.REQUEST_HINT:
|
| 152 |
self._state.hints_used += 1
|
|
|
|
| 153 |
hint_text = task_manager.get_hint(self._current_task, self._state.hints_used)
|
| 154 |
self._current_task["_last_hint"] = hint_text
|
| 155 |
|
| 156 |
+
# ββ Apply DB action and get delta βββββββββββββββββββββββββ
|
| 157 |
+
db_delta = 0.0
|
| 158 |
+
current_score = self._baseline_score
|
| 159 |
+
action_info = {}
|
| 160 |
+
|
| 161 |
+
if self._db_sim is not None:
|
| 162 |
+
payload = action.payload or {}
|
| 163 |
+
|
| 164 |
+
if action_type_enum == ActionType.INSPECT_QUERY:
|
| 165 |
+
qid = payload.get("query_id", "q1")
|
| 166 |
+
action_info = self._db_sim.inspect_query(qid)
|
| 167 |
+
self._current_task["_last_inspect"] = action_info
|
| 168 |
+
# No score change — investigation action
|
| 169 |
+
|
| 170 |
+
elif action_type_enum == ActionType.ANALYZE_INDEXES:
|
| 171 |
+
table = payload.get("table", "")
|
| 172 |
+
action_info = self._db_sim.analyze_indexes(table)
|
| 173 |
+
self._current_task["_last_analysis"] = action_info
|
| 174 |
+
|
| 175 |
+
elif action_type_enum == ActionType.CREATE_INDEX:
|
| 176 |
+
result = self._db_sim.apply_action("create_index", payload)
|
| 177 |
+
db_delta = result["delta"]
|
| 178 |
+
action_info = result
|
| 179 |
+
|
| 180 |
+
elif action_type_enum == ActionType.REWRITE_QUERY:
|
| 181 |
+
result = self._db_sim.apply_action("rewrite_query", payload)
|
| 182 |
+
db_delta = result["delta"]
|
| 183 |
+
action_info = result
|
| 184 |
+
|
| 185 |
+
elif action_type_enum == ActionType.ADD_COLUMN:
|
| 186 |
+
result = self._db_sim.apply_action("add_column", payload)
|
| 187 |
+
db_delta = result["delta"]
|
| 188 |
+
action_info = result
|
| 189 |
+
|
| 190 |
+
elif action_type_enum == ActionType.DROP_INDEX:
|
| 191 |
+
result = self._db_sim.apply_action("drop_index", payload)
|
| 192 |
+
db_delta = result["delta"]
|
| 193 |
+
action_info = result
|
| 194 |
+
|
| 195 |
+
elif action_type_enum == ActionType.PARTITION_TABLE:
|
| 196 |
+
result = self._db_sim.apply_action("partition_table", payload)
|
| 197 |
+
db_delta = result["delta"]
|
| 198 |
+
action_info = result
|
| 199 |
+
|
| 200 |
+
elif action_type_enum == ActionType.ANALYZE_STATS:
|
| 201 |
+
result = self._db_sim.apply_action("analyze_statistics", payload)
|
| 202 |
+
db_delta = result["delta"]
|
| 203 |
+
action_info = result
|
| 204 |
+
|
| 205 |
+
current_score = self._db_sim.get_performance_score()
|
| 206 |
+
|
| 207 |
+
# Update tracking in action_counts dict (used by /progress)
|
| 208 |
+
perf_history = self._state.action_counts.get("_perf_history", [])
|
| 209 |
+
perf_history.append(current_score)
|
| 210 |
+
self._state.action_counts["_perf_history"] = perf_history
|
| 211 |
+
self._state.action_counts["_best_score"] = self._db_sim.best_score
|
| 212 |
+
|
| 213 |
+
# ββ Compute reward ββββββββββββββββββββββββββββββββββββββββ
|
| 214 |
reward = compute_reward(
|
| 215 |
+
action = action,
|
| 216 |
+
task_id = self._state.task_id,
|
| 217 |
+
difficulty = self._state.difficulty,
|
| 218 |
+
step_count = self._state.step_count,
|
| 219 |
+
previous_actions = self._state.previous_actions[:-1],
|
| 220 |
+
hints_used = self._state.hints_used,
|
| 221 |
+
estimated_steps = self._current_task.get("estimated_fix_steps", MAX_STEPS),
|
| 222 |
+
action_counts = self._state.action_counts,
|
| 223 |
+
db_delta = db_delta,
|
| 224 |
+
baseline_score = self._baseline_score,
|
| 225 |
+
current_score = current_score,
|
| 226 |
+
milestones_earned = self._milestones_earned,
|
| 227 |
)
|
| 228 |
|
| 229 |
+
# Update milestone tracking
|
| 230 |
+
self._state.action_counts["_milestones"] = list(self._milestones_earned)
|
| 231 |
+
|
| 232 |
# ββ Update cumulative reward ββββββββββββββββββββββββββββββ
|
| 233 |
self._state.last_reward = reward.score
|
| 234 |
self._state.total_reward = round(self._state.total_reward + reward.score, 4)
|
| 235 |
|
| 236 |
+
# ββ Check done ββββββββββββββββββββββββββββββββββββββββββββ
|
| 237 |
+
target_reached = (
|
| 238 |
+
self._db_sim.is_target_reached() if self._db_sim else False
|
| 239 |
+
)
|
| 240 |
done = is_done(
|
| 241 |
+
action_type = action_type_enum,
|
| 242 |
+
step_count = self._state.step_count,
|
| 243 |
+
grader_score = reward.breakdown.get("grader_score", 0.0),
|
| 244 |
+
target_reached = target_reached,
|
| 245 |
)
|
| 246 |
self._state.done = done
|
| 247 |
|
| 248 |
+
# ββ Build observation βββββββββββββββββββββββββββββββββββββ
|
| 249 |
obs = self._build_observation()
|
| 250 |
|
| 251 |
+
# ββ Info dict βββββββββββββββββββββββββββββββββββββββββββββ
|
| 252 |
info = {
|
| 253 |
+
"step_count": self._state.step_count,
|
| 254 |
+
"total_reward": self._state.total_reward,
|
| 255 |
+
"hints_used": self._state.hints_used,
|
| 256 |
+
"task_id": self._state.task_id,
|
| 257 |
+
"difficulty": self._state.difficulty.value if self._state.difficulty else None,
|
| 258 |
+
"performance_score": current_score,
|
| 259 |
+
"db_delta": db_delta,
|
| 260 |
+
"milestones": list(self._milestones_earned),
|
| 261 |
+
"action_result": action_info,
|
| 262 |
}
|
| 263 |
if done:
|
| 264 |
info["episode_summary"] = {
|
| 265 |
+
"total_steps": self._state.step_count,
|
| 266 |
+
"total_reward": self._state.total_reward,
|
| 267 |
+
"hints_used": self._state.hints_used,
|
| 268 |
+
"duration_sec": round(time.time() - (self._started_at or time.time()), 2),
|
| 269 |
+
"final_score": current_score,
|
| 270 |
+
"baseline_score": self._baseline_score,
|
| 271 |
+
"improvement": round(current_score - self._baseline_score, 2),
|
| 272 |
+
"milestones_earned": list(self._milestones_earned),
|
| 273 |
}
|
| 274 |
|
| 275 |
+
# Normalize reward for validator compliance
|
| 276 |
normalized_score = max(0.001, min(0.999, (reward.score + 1.0) / 2.0))
|
| 277 |
reward = Reward(
|
| 278 |
score=normalized_score,
|
|
|
|
| 287 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 288 |
|
| 289 |
def state(self) -> EpisodeState:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
return self._state
|
| 291 |
|
| 292 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 294 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 295 |
|
| 296 |
def _build_observation(self) -> Observation:
|
| 297 |
+
"""Builds Observation from current state + DB simulator state."""
|
| 298 |
+
|
|
|
|
|
|
|
|
|
|
| 299 |
if self._current_task is None:
|
|
|
|
| 300 |
return Observation(
|
| 301 |
task_id = "none",
|
| 302 |
task_description = "No task loaded. Call reset() first.",
|
|
|
|
| 309 |
metadata = {}
|
| 310 |
)
|
| 311 |
|
| 312 |
+
# Base context from task
|
| 313 |
context = task_manager.build_observation_context(self._current_task)
|
| 314 |
|
| 315 |
+
# Inject DB simulator state
|
| 316 |
+
if self._db_sim is not None:
|
| 317 |
+
db_state = self._db_sim.get_current_state()
|
| 318 |
+
context.update({
|
| 319 |
+
"performance_score": db_state["performance_score"],
|
| 320 |
+
"target_score": db_state["target_score"],
|
| 321 |
+
"baseline_score": db_state["baseline_score"],
|
| 322 |
+
"tables": db_state["tables"],
|
| 323 |
+
"slow_queries": db_state["slow_queries"],
|
| 324 |
+
"indexes": db_state["indexes"],
|
| 325 |
+
"improvement_history": db_state["history"],
|
| 326 |
+
"best_score": db_state["best_score"],
|
| 327 |
+
"milestones_earned": list(self._milestones_earned),
|
| 328 |
+
})
|
| 329 |
+
|
| 330 |
+
# Inject last action result if available
|
| 331 |
+
if "_last_inspect" in self._current_task:
|
| 332 |
+
context["last_inspect_result"] = self._current_task["_last_inspect"]
|
| 333 |
+
if "_last_analysis" in self._current_task:
|
| 334 |
+
context["last_analysis_result"] = self._current_task["_last_analysis"]
|
| 335 |
if "_last_hint" in self._current_task:
|
| 336 |
context["last_hint"] = self._current_task["_last_hint"]
|
| 337 |
|
| 338 |
+
context["steps_remaining"] = MAX_STEPS - self._state.step_count
|
|
|
|
| 339 |
context["total_reward_so_far"] = self._state.total_reward
|
| 340 |
|
| 341 |
return Observation(
|
|
|
|
| 348 |
hints_used = self._state.hints_used,
|
| 349 |
previous_actions = self._state.previous_actions.copy(),
|
| 350 |
metadata = {
|
| 351 |
+
"category": self._current_task.get("category", ""),
|
| 352 |
+
"baseline_score": self._baseline_score,
|
| 353 |
+
"target_score": self._current_task.get("target_score", 85.0),
|
| 354 |
+
"total_reward": self._state.total_reward,
|
| 355 |
+
"milestones": list(self._milestones_earned),
|
| 356 |
}
|
| 357 |
)
|
| 358 |
|
|
|
|
| 361 |
# SINGLETON INSTANCE (used by FastAPI)
|
| 362 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 363 |
|
| 364 |
+
environment = SQLDebuggerEnvironment()
|
env/models.py
CHANGED
|
@@ -4,7 +4,9 @@ from enum import Enum
|
|
| 4 |
import time
|
| 5 |
|
| 6 |
|
|
|
|
| 7 |
# ENUMS
|
|
|
|
| 8 |
|
| 9 |
class DifficultyLevel(str, Enum):
|
| 10 |
EASY = "easy"
|
|
@@ -13,42 +15,57 @@ class DifficultyLevel(str, Enum):
|
|
| 13 |
|
| 14 |
|
| 15 |
class ActionType(str, Enum):
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
class Observation(BaseModel):
|
| 27 |
-
task_id: str
|
| 28 |
-
task_description: str
|
| 29 |
-
current_context: dict
|
| 30 |
-
step_count: int
|
| 31 |
difficulty: DifficultyLevel = Field(..., description="Task difficulty level")
|
| 32 |
-
max_steps: int
|
| 33 |
-
hints_used: int
|
| 34 |
-
previous_actions: list[str]
|
| 35 |
-
metadata: dict
|
| 36 |
|
| 37 |
model_config = {"json_schema_extra": {
|
| 38 |
"example": {
|
| 39 |
-
"task_id": "
|
| 40 |
-
"task_description": "
|
| 41 |
"current_context": {
|
| 42 |
-
"
|
| 43 |
-
"
|
| 44 |
-
"
|
|
|
|
| 45 |
},
|
| 46 |
"step_count": 0,
|
| 47 |
"difficulty": "easy",
|
| 48 |
-
"max_steps":
|
| 49 |
"hints_used": 0,
|
| 50 |
"previous_actions": [],
|
| 51 |
-
"metadata": {"
|
| 52 |
}
|
| 53 |
}}
|
| 54 |
|
|
@@ -67,7 +84,6 @@ class Action(BaseModel):
|
|
| 67 |
@field_validator("payload")
|
| 68 |
@classmethod
|
| 69 |
def truncate_long_strings(cls, v):
|
| 70 |
-
# Edge case: extremely long agent output β truncate gracefully
|
| 71 |
def truncate(obj, max_len=5000):
|
| 72 |
if isinstance(obj, str) and len(obj) > max_len:
|
| 73 |
return obj[:max_len] + "...[truncated]"
|
|
@@ -78,12 +94,10 @@ class Action(BaseModel):
|
|
| 78 |
|
| 79 |
model_config = {"json_schema_extra": {
|
| 80 |
"example": {
|
| 81 |
-
"action_type": "
|
| 82 |
"payload": {
|
| 83 |
-
"
|
| 84 |
-
"
|
| 85 |
-
"error_type": "syntax",
|
| 86 |
-
"confidence": 0.95
|
| 87 |
}
|
| 88 |
}
|
| 89 |
}}
|
|
@@ -103,49 +117,53 @@ class Reward(BaseModel):
|
|
| 103 |
"example": {
|
| 104 |
"score": 0.75,
|
| 105 |
"breakdown": {
|
| 106 |
-
"
|
| 107 |
-
"
|
| 108 |
-
"
|
| 109 |
-
"
|
| 110 |
},
|
| 111 |
-
"feedback": "
|
| 112 |
}
|
| 113 |
}}
|
| 114 |
|
| 115 |
|
|
|
|
| 116 |
# EPISODE STATE (used by state() endpoint)
|
|
|
|
| 117 |
|
| 118 |
class EpisodeState(BaseModel):
|
| 119 |
-
task_id: Optional[str]
|
| 120 |
difficulty: Optional[DifficultyLevel] = Field(default=None)
|
| 121 |
step_count: int = Field(default=0)
|
| 122 |
total_reward: float = Field(default=0.0)
|
| 123 |
done: bool = Field(default=False)
|
| 124 |
hints_used: int = Field(default=0)
|
| 125 |
previous_actions: list[str] = Field(default_factory=list)
|
| 126 |
-
action_counts: dict[str,
|
| 127 |
started_at: Optional[float] = Field(default=None)
|
| 128 |
last_reward: float = Field(default=0.0)
|
| 129 |
initialized: bool = Field(default=False)
|
| 130 |
|
| 131 |
model_config = {"json_schema_extra": {
|
| 132 |
"example": {
|
| 133 |
-
"task_id": "
|
| 134 |
-
"difficulty": "
|
| 135 |
"step_count": 3,
|
| 136 |
-
"total_reward": 0.
|
| 137 |
"done": False,
|
| 138 |
-
"hints_used":
|
| 139 |
-
"previous_actions": ["
|
| 140 |
-
"action_counts": {"
|
| 141 |
"started_at": 1700000000.0,
|
| 142 |
-
"last_reward": 0.
|
| 143 |
"initialized": True
|
| 144 |
}
|
| 145 |
}}
|
| 146 |
|
| 147 |
|
|
|
|
| 148 |
# API REQUEST / RESPONSE WRAPPERS
|
|
|
|
| 149 |
|
| 150 |
class StepResponse(BaseModel):
|
| 151 |
observation: Observation
|
|
@@ -157,15 +175,15 @@ class ResetResponse(BaseModel):
|
|
| 157 |
observation: Observation
|
| 158 |
|
| 159 |
class TaskInfo(BaseModel):
|
| 160 |
-
id:
|
| 161 |
-
difficulty:
|
| 162 |
-
description:
|
| 163 |
-
action_schema: dict
|
| 164 |
|
| 165 |
class TaskListResponse(BaseModel):
|
| 166 |
-
tasks:
|
| 167 |
-
total:
|
| 168 |
-
action_types:
|
| 169 |
|
| 170 |
class BaselineResult(BaseModel):
|
| 171 |
task_id: str
|
|
@@ -180,7 +198,7 @@ class BaselineResult(BaseModel):
|
|
| 180 |
return max(0.001, min(0.999, round(float(v), 4)))
|
| 181 |
|
| 182 |
class BaselineResponse(BaseModel):
|
| 183 |
-
results:
|
| 184 |
average_score: float
|
| 185 |
completed_at: float = Field(default_factory=time.time)
|
| 186 |
|
|
@@ -201,13 +219,30 @@ class GraderResponse(BaseModel):
|
|
| 201 |
|
| 202 |
model_config = {"json_schema_extra": {
|
| 203 |
"example": {
|
| 204 |
-
"score":
|
| 205 |
-
"feedback": "
|
| 206 |
-
"breakdown": {"
|
| 207 |
}
|
| 208 |
}}
|
| 209 |
|
| 210 |
class HealthResponse(BaseModel):
|
| 211 |
-
status: str
|
| 212 |
-
version: str
|
| 213 |
-
uptime: float = Field(default_factory=time.time)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import time
|
| 5 |
|
| 6 |
|
| 7 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 8 |
# ENUMS
|
| 9 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 10 |
|
| 11 |
class DifficultyLevel(str, Enum):
|
| 12 |
EASY = "easy"
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
class ActionType(str, Enum):
|
| 18 |
+
# ββ Round 1 actions (keep β backward compatible) ββ
|
| 19 |
+
IDENTIFY_ERROR = "identify_error"
|
| 20 |
+
PROPOSE_FIX = "propose_fix"
|
| 21 |
+
SUBMIT_ANSWER = "submit_answer"
|
| 22 |
+
REQUEST_HINT = "request_hint"
|
| 23 |
+
EXPLAIN_ISSUE = "explain_issue"
|
| 24 |
+
OPTIMIZE_QUERY = "optimize_query"
|
| 25 |
+
|
| 26 |
+
# ββ Round 2 new actions ββ
|
| 27 |
+
INSPECT_QUERY = "inspect_query"
|
| 28 |
+
ANALYZE_INDEXES = "analyze_indexes"
|
| 29 |
+
CREATE_INDEX = "create_index"
|
| 30 |
+
REWRITE_QUERY = "rewrite_query"
|
| 31 |
+
ADD_COLUMN = "add_column"
|
| 32 |
+
DROP_INDEX = "drop_index"
|
| 33 |
+
PARTITION_TABLE = "partition_table"
|
| 34 |
+
ANALYZE_STATS = "analyze_statistics"
|
| 35 |
+
SUBMIT_REPORT = "submit_report"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 39 |
+
# CORE MODELS
|
| 40 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 41 |
|
| 42 |
class Observation(BaseModel):
|
| 43 |
+
task_id: str = Field(..., description="Unique task identifier")
|
| 44 |
+
task_description: str = Field(..., description="What the agent must do")
|
| 45 |
+
current_context: dict = Field(..., description="What the agent currently sees")
|
| 46 |
+
step_count: int = Field(default=0, ge=0, description="Steps taken so far")
|
| 47 |
difficulty: DifficultyLevel = Field(..., description="Task difficulty level")
|
| 48 |
+
max_steps: int = Field(default=50, description="Maximum steps allowed")
|
| 49 |
+
hints_used: int = Field(default=0, description="Number of hints used")
|
| 50 |
+
previous_actions: list[str] = Field(default_factory=list, description="History of action types taken")
|
| 51 |
+
metadata: dict = Field(default_factory=dict, description="Extra task metadata")
|
| 52 |
|
| 53 |
model_config = {"json_schema_extra": {
|
| 54 |
"example": {
|
| 55 |
+
"task_id": "easy_s001",
|
| 56 |
+
"task_description": "Optimize a slow user lookup query on 10K users table.",
|
| 57 |
"current_context": {
|
| 58 |
+
"tables": [{"name": "users", "rows": 10000, "indexes": ["PRIMARY"]}],
|
| 59 |
+
"slow_queries": [{"id": "q1", "sql": "SELECT * FROM users WHERE email=?", "avg_ms": 2000}],
|
| 60 |
+
"performance_score": 8.0,
|
| 61 |
+
"target_score": 80.0
|
| 62 |
},
|
| 63 |
"step_count": 0,
|
| 64 |
"difficulty": "easy",
|
| 65 |
+
"max_steps": 50,
|
| 66 |
"hints_used": 0,
|
| 67 |
"previous_actions": [],
|
| 68 |
+
"metadata": {"scenario_id": "easy_s001", "baseline_score": 8.0}
|
| 69 |
}
|
| 70 |
}}
|
| 71 |
|
|
|
|
| 84 |
@field_validator("payload")
|
| 85 |
@classmethod
|
| 86 |
def truncate_long_strings(cls, v):
|
|
|
|
| 87 |
def truncate(obj, max_len=5000):
|
| 88 |
if isinstance(obj, str) and len(obj) > max_len:
|
| 89 |
return obj[:max_len] + "...[truncated]"
|
|
|
|
| 94 |
|
| 95 |
model_config = {"json_schema_extra": {
|
| 96 |
"example": {
|
| 97 |
+
"action_type": "create_index",
|
| 98 |
"payload": {
|
| 99 |
+
"table": "users",
|
| 100 |
+
"columns": ["email"]
|
|
|
|
|
|
|
| 101 |
}
|
| 102 |
}
|
| 103 |
}}
|
|
|
|
| 117 |
"example": {
|
| 118 |
"score": 0.75,
|
| 119 |
"breakdown": {
|
| 120 |
+
"step_reward": 0.05,
|
| 121 |
+
"delta_reward": 0.40,
|
| 122 |
+
"milestone_bonus": 0.15,
|
| 123 |
+
"total": 0.60
|
| 124 |
},
|
| 125 |
+
"feedback": "Index created. Performance improved 55%. Milestone bonus earned!"
|
| 126 |
}
|
| 127 |
}}
|
| 128 |
|
| 129 |
|
| 130 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 131 |
# EPISODE STATE (used by state() endpoint)
|
| 132 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 133 |
|
| 134 |
class EpisodeState(BaseModel):
|
| 135 |
+
task_id: Optional[str] = Field(default=None)
|
| 136 |
difficulty: Optional[DifficultyLevel] = Field(default=None)
|
| 137 |
step_count: int = Field(default=0)
|
| 138 |
total_reward: float = Field(default=0.0)
|
| 139 |
done: bool = Field(default=False)
|
| 140 |
hints_used: int = Field(default=0)
|
| 141 |
previous_actions: list[str] = Field(default_factory=list)
|
| 142 |
+
action_counts: dict[str, Any] = Field(default_factory=dict)
|
| 143 |
started_at: Optional[float] = Field(default=None)
|
| 144 |
last_reward: float = Field(default=0.0)
|
| 145 |
initialized: bool = Field(default=False)
|
| 146 |
|
| 147 |
model_config = {"json_schema_extra": {
|
| 148 |
"example": {
|
| 149 |
+
"task_id": "easy_s001",
|
| 150 |
+
"difficulty": "easy",
|
| 151 |
"step_count": 3,
|
| 152 |
+
"total_reward": 0.65,
|
| 153 |
"done": False,
|
| 154 |
+
"hints_used": 0,
|
| 155 |
+
"previous_actions": ["inspect_query", "analyze_indexes", "create_index"],
|
| 156 |
+
"action_counts": {"inspect_query": 1, "analyze_indexes": 1, "create_index": 1},
|
| 157 |
"started_at": 1700000000.0,
|
| 158 |
+
"last_reward": 0.45,
|
| 159 |
"initialized": True
|
| 160 |
}
|
| 161 |
}}
|
| 162 |
|
| 163 |
|
| 164 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 165 |
# API REQUEST / RESPONSE WRAPPERS
|
| 166 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 167 |
|
| 168 |
class StepResponse(BaseModel):
|
| 169 |
observation: Observation
|
|
|
|
| 175 |
observation: Observation
|
| 176 |
|
| 177 |
class TaskInfo(BaseModel):
|
| 178 |
+
id: str
|
| 179 |
+
difficulty: DifficultyLevel
|
| 180 |
+
description: str
|
| 181 |
+
action_schema: dict
|
| 182 |
|
| 183 |
class TaskListResponse(BaseModel):
|
| 184 |
+
tasks: list[TaskInfo]
|
| 185 |
+
total: int
|
| 186 |
+
action_types: list[str]
|
| 187 |
|
| 188 |
class BaselineResult(BaseModel):
|
| 189 |
task_id: str
|
|
|
|
| 198 |
return max(0.001, min(0.999, round(float(v), 4)))
|
| 199 |
|
| 200 |
class BaselineResponse(BaseModel):
|
| 201 |
+
results: list[BaselineResult]
|
| 202 |
average_score: float
|
| 203 |
completed_at: float = Field(default_factory=time.time)
|
| 204 |
|
|
|
|
| 219 |
|
| 220 |
model_config = {"json_schema_extra": {
|
| 221 |
"example": {
|
| 222 |
+
"score": 0.82,
|
| 223 |
+
"feedback": "Performance improved from 12.5 to 85.0. Excellent optimization!",
|
| 224 |
+
"breakdown": {"perf_improvement": 0.60, "step_efficiency": 0.12, "index_quality": 0.10}
|
| 225 |
}
|
| 226 |
}}
|
| 227 |
|
| 228 |
class HealthResponse(BaseModel):
|
| 229 |
+
status: str = "ok"
|
| 230 |
+
version: str = "2.0.0"
|
| 231 |
+
uptime: float = Field(default_factory=time.time)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# ─────────────────────────────────────────────
|
| 235 |
+
# ROUND 2 — PROGRESS RESPONSE
|
| 236 |
+
# ─────────────────────────────────────────────
|
| 237 |
+
|
| 238 |
+
class ProgressResponse(BaseModel):
|
| 239 |
+
scenario_id: Optional[str] = Field(default=None)
|
| 240 |
+
performance_score: float = Field(default=0.0, description="Current DB performance score 0-100")
|
| 241 |
+
baseline_score: float = Field(default=0.0, description="Starting score this episode")
|
| 242 |
+
target_score: float = Field(default=85.0, description="Score needed to succeed")
|
| 243 |
+
improvement_history: list[float] = Field(default_factory=list)
|
| 244 |
+
milestones_earned: list[float] = Field(default_factory=list)
|
| 245 |
+
best_score: float = Field(default=0.0)
|
| 246 |
+
steps_used: int = Field(default=0)
|
| 247 |
+
budget_remaining: int = Field(default=50)
|
| 248 |
+
total_reward: float = Field(default=0.0)
|
env/reward.py
CHANGED
|
@@ -1,41 +1,94 @@
|
|
| 1 |
from env.models import Action, Reward, DifficultyLevel, ActionType
|
| 2 |
from env.graders import grade
|
| 3 |
|
|
|
|
| 4 |
# CONSTANTS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
HINT_PENALTY = -0.05 # Per hint requested
|
| 8 |
-
LOOP_PENALTY = -0.05 # Same action 3+ times in a row
|
| 9 |
-
INVALID_PENALTY = -0.10 # Null / malformed action
|
| 10 |
-
STEP_EFFICIENCY_BONUS = 0.10 # Bonus for solving in fewer steps than estimated
|
| 11 |
-
|
| 12 |
-
# Dense reward per action type (before grader score)
|
| 13 |
STEP_REWARDS = {
|
| 14 |
-
|
| 15 |
-
ActionType.
|
| 16 |
-
ActionType.
|
| 17 |
-
ActionType.
|
| 18 |
-
ActionType.
|
| 19 |
-
ActionType.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
}
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
-
# LOOP DETECTOR
|
| 24 |
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
"""
|
| 28 |
-
Returns
|
| 29 |
-
|
| 30 |
"""
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
return False
|
| 33 |
-
|
| 34 |
-
return
|
| 35 |
|
| 36 |
|
| 37 |
def _count_consecutive(previous_actions: list[str], current_action: str) -> int:
|
| 38 |
-
"""Count how many times the current action has been repeated consecutively."""
|
| 39 |
count = 1
|
| 40 |
for a in reversed(previous_actions):
|
| 41 |
if a == current_action:
|
|
@@ -45,24 +98,22 @@ def _count_consecutive(previous_actions: list[str], current_action: str) -> int:
|
|
| 45 |
return count
|
| 46 |
|
| 47 |
|
|
|
|
| 48 |
# EFFICIENCY BONUS
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def _efficiency_bonus(step_count: int,
|
| 52 |
-
"""
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
return 0.0
|
| 58 |
-
if step_count <= estimated_steps:
|
| 59 |
-
ratio = step_count / estimated_steps
|
| 60 |
-
# More bonus the faster β scales from 0.10 down to 0.0
|
| 61 |
-
return round(STEP_EFFICIENCY_BONUS * (1.0 - ratio + 0.1), 4)
|
| 62 |
return 0.0
|
| 63 |
|
| 64 |
|
|
|
|
| 65 |
# MAIN REWARD FUNCTION
|
|
|
|
| 66 |
|
| 67 |
def compute_reward(
|
| 68 |
action: Action,
|
|
@@ -73,25 +124,33 @@ def compute_reward(
|
|
| 73 |
hints_used: int,
|
| 74 |
estimated_steps: int,
|
| 75 |
action_counts: dict[str, int],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
) -> Reward:
|
| 77 |
"""
|
| 78 |
-
Computes
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
"""
|
| 91 |
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
| 93 |
feedback_parts = []
|
| 94 |
-
final_score
|
| 95 |
|
| 96 |
# ββ Edge case: null action ββββββββββββββββββββββββββββββββββββ
|
| 97 |
if action is None or action.payload is None:
|
|
@@ -100,105 +159,155 @@ def compute_reward(
|
|
| 100 |
breakdown={"invalid_action": 0.001},
|
| 101 |
feedback="Invalid or null action received."
|
| 102 |
)
|
| 103 |
-
|
|
|
|
| 104 |
action_type_enum = action.action_type
|
| 105 |
|
| 106 |
-
# ββ 1. Step reward
|
| 107 |
step_reward = STEP_REWARDS.get(action_type_enum, 0.05)
|
| 108 |
breakdown["step_reward"] = round(step_reward, 4)
|
| 109 |
final_score += step_reward
|
| 110 |
if step_reward > 0:
|
| 111 |
-
feedback_parts.append(f"Action '{action_type_val}'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
-
# ββ
|
| 114 |
grader_score = 0.0
|
| 115 |
-
is_terminal = action_type_enum in
|
| 116 |
|
| 117 |
-
if is_terminal:
|
| 118 |
raw_score, grader_breakdown, grader_feedback = grade(action, task_id)
|
| 119 |
grader_score = raw_score
|
| 120 |
-
breakdown["grader_score"]
|
| 121 |
breakdown["grader_breakdown"] = grader_breakdown
|
| 122 |
final_score += grader_score
|
| 123 |
feedback_parts.append(grader_feedback)
|
| 124 |
|
| 125 |
-
# Efficiency bonus β only on correct terminal action
|
| 126 |
if grader_score >= 0.5:
|
| 127 |
-
eff_bonus = _efficiency_bonus(step_count,
|
| 128 |
if eff_bonus > 0:
|
| 129 |
final_score += eff_bonus
|
| 130 |
breakdown["efficiency_bonus"] = round(eff_bonus, 4)
|
| 131 |
-
feedback_parts.append(f"Efficiency bonus +{eff_bonus}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
elif action_type_enum == ActionType.PROPOSE_FIX:
|
| 134 |
-
# Partial grader score for propose_fix β encourages iterative improvement
|
| 135 |
raw_score, grader_breakdown, _ = grade(action, task_id)
|
| 136 |
-
partial = round(raw_score * 0.4, 4)
|
| 137 |
-
grader_score = partial
|
| 138 |
breakdown["partial_grader_score"] = partial
|
| 139 |
final_score += partial
|
| 140 |
-
if partial > 0:
|
| 141 |
-
feedback_parts.append(f"Partial fix credit +{partial}.")
|
| 142 |
|
| 143 |
elif action_type_enum == ActionType.IDENTIFY_ERROR:
|
| 144 |
-
# Small grader check on error identification
|
| 145 |
raw_score, _, _ = grade(action, task_id)
|
| 146 |
-
partial = round(raw_score * 0.2, 4)
|
| 147 |
breakdown["identification_score"] = partial
|
| 148 |
final_score += partial
|
| 149 |
|
| 150 |
-
# ββ
|
| 151 |
if _detect_loop(previous_actions, action_type_val):
|
| 152 |
consecutive = _count_consecutive(previous_actions, action_type_val)
|
| 153 |
-
loop_pen = LOOP_PENALTY * min(consecutive -
|
| 154 |
final_score += loop_pen
|
| 155 |
breakdown["loop_penalty"] = round(loop_pen, 4)
|
| 156 |
-
feedback_parts.append(f"Loop detected ({consecutive}x
|
| 157 |
|
| 158 |
-
# ββ
|
| 159 |
if action_type_enum == ActionType.REQUEST_HINT:
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
final_score = round(max(0.001, min(0.999, final_score)), 4)
|
| 174 |
breakdown["total"] = final_score
|
| 175 |
|
| 176 |
feedback = " ".join(feedback_parts) if feedback_parts else "Step processed."
|
| 177 |
|
| 178 |
-
return Reward(
|
| 179 |
-
score=final_score,
|
| 180 |
-
breakdown=breakdown,
|
| 181 |
-
feedback=feedback
|
| 182 |
-
)
|
| 183 |
|
| 184 |
|
|
|
|
| 185 |
# EPISODE DONE CONDITION
|
|
|
|
| 186 |
|
| 187 |
def is_done(
|
| 188 |
-
action_type:
|
| 189 |
-
step_count:
|
| 190 |
-
grader_score:
|
|
|
|
| 191 |
) -> bool:
|
| 192 |
"""
|
| 193 |
Episode ends when:
|
| 194 |
-
1. Agent submits
|
| 195 |
2. Max steps reached
|
| 196 |
-
3. Perfect score
|
| 197 |
"""
|
| 198 |
-
if action_type in
|
| 199 |
return True
|
| 200 |
if step_count >= MAX_STEPS:
|
| 201 |
return True
|
| 202 |
if grader_score >= 1.0:
|
| 203 |
return True
|
| 204 |
-
|
|
|
|
|
|
|
|
|
| 1 |
from env.models import Action, Reward, DifficultyLevel, ActionType
|
| 2 |
from env.graders import grade
|
| 3 |
|
| 4 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 5 |
# CONSTANTS
|
| 6 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 7 |
+
|
| 8 |
+
MAX_STEPS = 50 # Round 2: long-horizon episodes
|
| 9 |
+
HINT_PENALTY = -0.10 # Per hint requested (increased from Round 1)
|
| 10 |
+
LOOP_PENALTY = -0.08 # Same action on same target 2+ times, no improvement
|
| 11 |
+
INVALID_PENALTY = -0.10 # Null / malformed action
|
| 12 |
+
BACKTRACK_PENALTY = -0.05 # Action makes score worse than previous best
|
| 13 |
+
BUDGET_EXHAUSTION_PEN = -0.15 # Reaching max_steps without submitting report
|
| 14 |
+
EFFICIENCY_BONUS = 0.10 # Solved in < 70% of max_steps
|
| 15 |
+
|
| 16 |
+
# Milestone thresholds: {improvement_fraction: bonus_reward}
|
| 17 |
+
MILESTONE_THRESHOLDS = {
|
| 18 |
+
0.25: 0.15, # 25% improvement β +0.15 bonus
|
| 19 |
+
0.50: 0.25, # 50% improvement β +0.25 bonus
|
| 20 |
+
0.75: 0.40, # 75% improvement β +0.40 bonus
|
| 21 |
+
}
|
| 22 |
|
| 23 |
+
# Step rewards for Round 2 actions (dense signal)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
STEP_REWARDS = {
|
| 25 |
+
# ββ Round 2 actions ββββββββββββββββββββββββββ
|
| 26 |
+
ActionType.INSPECT_QUERY: 0.05, # Investigation rewarded
|
| 27 |
+
ActionType.ANALYZE_INDEXES: 0.05, # Investigation rewarded
|
| 28 |
+
ActionType.CREATE_INDEX: 0.10, # Core optimization action
|
| 29 |
+
ActionType.REWRITE_QUERY: 0.15, # High-value rewrite
|
| 30 |
+
ActionType.ADD_COLUMN: 0.08, # Denormalization
|
| 31 |
+
ActionType.DROP_INDEX: 0.05, # Clean up overhead
|
| 32 |
+
ActionType.PARTITION_TABLE: 0.15, # Big structural improvement
|
| 33 |
+
ActionType.ANALYZE_STATS: 0.05, # Maintenance action
|
| 34 |
+
ActionType.SUBMIT_REPORT: 0.00, # Terminal β score comes from grader
|
| 35 |
+
ActionType.REQUEST_HINT: 0.00, # No reward, only penalty
|
| 36 |
+
# ββ Round 1 backward compat ββββββββββββββββββ
|
| 37 |
+
ActionType.IDENTIFY_ERROR: 0.15,
|
| 38 |
+
ActionType.PROPOSE_FIX: 0.25,
|
| 39 |
+
ActionType.SUBMIT_ANSWER: 0.00,
|
| 40 |
+
ActionType.EXPLAIN_ISSUE: 0.10,
|
| 41 |
+
ActionType.OPTIMIZE_QUERY: 0.20,
|
| 42 |
}
|
| 43 |
|
| 44 |
+
# Terminal actions that end the episode
|
| 45 |
+
TERMINAL_ACTIONS = {
|
| 46 |
+
ActionType.SUBMIT_ANSWER,
|
| 47 |
+
ActionType.OPTIMIZE_QUERY,
|
| 48 |
+
ActionType.SUBMIT_REPORT,
|
| 49 |
+
}
|
| 50 |
|
|
|
|
| 51 |
|
| 52 |
+
# ─────────────────────────────────────────────
|
| 53 |
+
# MILESTONE TRACKER
|
| 54 |
+
# ─────────────────────────────────────────────
|
| 55 |
|
| 56 |
+
def check_milestones(
    baseline_score: float,
    new_score: float,
    earned: set,
    thresholds: "dict[float, float] | None" = None,
) -> tuple[float, list[float]]:
    """Award one-time bonuses for improvement milestones crossed so far.

    Improvement is measured as the fraction of the available headroom
    (``100 - baseline_score``, floored at 1.0 to avoid division by zero)
    that ``new_score`` has recovered.

    Args:
        baseline_score: Performance score at the start of the episode.
        new_score: Current performance score.
        earned: Set of thresholds already paid out. It is mutated in place:
            every newly crossed threshold is added so that it is never
            paid twice. ``None`` is treated as an empty, throw-away set.
        thresholds: Optional ``{improvement_fraction: bonus}`` table.
            Defaults to the module-level ``MILESTONE_THRESHOLDS``.

    Returns:
        ``(total_bonus, newly_earned_thresholds)`` where ``total_bonus`` is
        rounded to 4 decimals and the thresholds are in ascending order.
    """
    if thresholds is None:
        thresholds = MILESTONE_THRESHOLDS
    if earned is None:
        # Caller is not tracking milestones; nothing can be persisted.
        earned = set()

    headroom = max(1.0, 100.0 - baseline_score)
    improvement = (new_score - baseline_score) / headroom

    bonus = 0.0
    newly_earned = []
    # Ascending order keeps the returned list deterministic regardless of
    # how the threshold table was built.
    for threshold, reward in sorted(thresholds.items()):
        if improvement >= threshold and threshold not in earned:
            bonus += reward
            newly_earned.append(threshold)
            earned.add(threshold)

    return round(bonus, 4), newly_earned
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# ─────────────────────────────────────────────
|
| 80 |
+
# LOOP DETECTOR
|
| 81 |
+
# ─────────────────────────────────────────────
|
| 82 |
+
|
| 83 |
+
def _detect_loop(previous_actions: list[str], current_action: str) -> bool:
|
| 84 |
+
"""Returns True if agent has done the same action 2+ times in a row."""
|
| 85 |
+
if len(previous_actions) < 1:
|
| 86 |
return False
|
| 87 |
+
last = previous_actions[-1]
|
| 88 |
+
return last == current_action
|
| 89 |
|
| 90 |
|
| 91 |
def _count_consecutive(previous_actions: list[str], current_action: str) -> int:
|
|
|
|
| 92 |
count = 1
|
| 93 |
for a in reversed(previous_actions):
|
| 94 |
if a == current_action:
|
|
|
|
| 98 |
return count
|
| 99 |
|
| 100 |
|
| 101 |
+
# ─────────────────────────────────────────────
|
| 102 |
# EFFICIENCY BONUS
|
| 103 |
+
# ─────────────────────────────────────────────
|
| 104 |
+
|
| 105 |
+
def _efficiency_bonus(step_count: int, max_steps: int) -> float:
|
| 106 |
+
"""Bonus if agent finishes in < 70% of budget."""
|
| 107 |
+
threshold = max_steps * 0.70
|
| 108 |
+
if step_count <= threshold:
|
| 109 |
+
ratio = step_count / max(1, max_steps)
|
| 110 |
+
return round(EFFICIENCY_BONUS * (1.0 - ratio), 4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
return 0.0
|
| 112 |
|
| 113 |
|
| 114 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 115 |
# MAIN REWARD FUNCTION
|
| 116 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 117 |
|
| 118 |
def compute_reward(
|
| 119 |
action: Action,
|
|
|
|
| 124 |
hints_used: int,
|
| 125 |
estimated_steps: int,
|
| 126 |
action_counts: dict[str, int],
|
| 127 |
+
# Round 2 extras (optional β backward compatible)
|
| 128 |
+
db_delta: float = 0.0, # Performance score delta from DatabaseSimulator
|
| 129 |
+
baseline_score: float = 0.0, # Scenario baseline score
|
| 130 |
+
current_score: float = 0.0, # Current DB performance score
|
| 131 |
+
milestones_earned: set = None, # Set of already-earned milestone thresholds
|
| 132 |
) -> Reward:
|
| 133 |
"""
|
| 134 |
+
Computes dense reward signal for every step.
|
| 135 |
+
|
| 136 |
+
Components:
|
| 137 |
+
1. Step reward β small reward for valid action type
|
| 138 |
+
2. Delta reward β proportional to DB performance improvement (Round 2)
|
| 139 |
+
3. Milestone bonus β one-time bonus at 25%/50%/75% improvement
|
| 140 |
+
4. Grader score β full score on terminal actions (Round 1 compat)
|
| 141 |
+
5. Loop penalty β repeated same action with no improvement
|
| 142 |
+
6. Hint penalty β cost per hint
|
| 143 |
+
7. Backtrack penalty β action made things worse
|
| 144 |
+
8. Budget penalty β approaching max_steps without submitting
|
| 145 |
+
9. Efficiency bonus β solved fast
|
| 146 |
"""
|
| 147 |
|
| 148 |
+
if milestones_earned is None:
|
| 149 |
+
milestones_earned = set()
|
| 150 |
+
|
| 151 |
+
breakdown = {}
|
| 152 |
feedback_parts = []
|
| 153 |
+
final_score = 0.0
|
| 154 |
|
| 155 |
# ββ Edge case: null action ββββββββββββββββββββββββββββββββββββ
|
| 156 |
if action is None or action.payload is None:
|
|
|
|
| 159 |
breakdown={"invalid_action": 0.001},
|
| 160 |
feedback="Invalid or null action received."
|
| 161 |
)
|
| 162 |
+
|
| 163 |
+
action_type_val = action.action_type.value if hasattr(action.action_type, "value") else str(action.action_type)
|
| 164 |
action_type_enum = action.action_type
|
| 165 |
|
| 166 |
+
# ββ 1. Step reward ββββββββββββββββββββββββββββββββββββββββββββ
|
| 167 |
step_reward = STEP_REWARDS.get(action_type_enum, 0.05)
|
| 168 |
breakdown["step_reward"] = round(step_reward, 4)
|
| 169 |
final_score += step_reward
|
| 170 |
if step_reward > 0:
|
| 171 |
+
feedback_parts.append(f"Action '{action_type_val}' +{step_reward}.")
|
| 172 |
+
|
| 173 |
+
# ββ 2. Delta reward (Round 2 DB performance change) βββββββββββ
|
| 174 |
+
if db_delta != 0.0:
|
| 175 |
+
delta_reward = round((db_delta / 100.0) * 0.40, 4)
|
| 176 |
+
delta_reward = max(-0.40, min(0.40, delta_reward))
|
| 177 |
+
breakdown["delta_reward"] = delta_reward
|
| 178 |
+
final_score += delta_reward
|
| 179 |
+
if delta_reward > 0:
|
| 180 |
+
feedback_parts.append(f"DB improved +{db_delta:.1f} pts. Delta reward +{delta_reward}.")
|
| 181 |
+
elif delta_reward < 0:
|
| 182 |
+
feedback_parts.append(f"DB worsened {db_delta:.1f} pts. Penalty {delta_reward}.")
|
| 183 |
+
|
| 184 |
+
# ββ 3. Milestone bonuses ββββββββββββββββββββββββββββββββββββββ
|
| 185 |
+
if baseline_score > 0 and current_score > 0:
|
| 186 |
+
milestone_bonus, newly_earned = check_milestones(
|
| 187 |
+
baseline_score, current_score, milestones_earned
|
| 188 |
+
)
|
| 189 |
+
if milestone_bonus > 0:
|
| 190 |
+
breakdown["milestone_bonus"] = milestone_bonus
|
| 191 |
+
final_score += milestone_bonus
|
| 192 |
+
pct = int(max(newly_earned) * 100)
|
| 193 |
+
feedback_parts.append(f"π― Milestone! {pct}% improvement. Bonus +{milestone_bonus}!")
|
| 194 |
|
| 195 |
+
# ββ 4. Grader score for terminal actions (Round 1 compat) βββββ
|
| 196 |
grader_score = 0.0
|
| 197 |
+
is_terminal = action_type_enum in TERMINAL_ACTIONS
|
| 198 |
|
| 199 |
+
if is_terminal and action_type_enum != ActionType.SUBMIT_REPORT:
|
| 200 |
raw_score, grader_breakdown, grader_feedback = grade(action, task_id)
|
| 201 |
grader_score = raw_score
|
| 202 |
+
breakdown["grader_score"] = round(grader_score, 4)
|
| 203 |
breakdown["grader_breakdown"] = grader_breakdown
|
| 204 |
final_score += grader_score
|
| 205 |
feedback_parts.append(grader_feedback)
|
| 206 |
|
|
|
|
| 207 |
if grader_score >= 0.5:
|
| 208 |
+
eff_bonus = _efficiency_bonus(step_count, MAX_STEPS)
|
| 209 |
if eff_bonus > 0:
|
| 210 |
final_score += eff_bonus
|
| 211 |
breakdown["efficiency_bonus"] = round(eff_bonus, 4)
|
| 212 |
+
feedback_parts.append(f"Efficiency bonus +{eff_bonus}.")
|
| 213 |
+
|
| 214 |
+
elif is_terminal and action_type_enum == ActionType.SUBMIT_REPORT:
|
| 215 |
+
# Round 2 terminal: compute from DB performance
|
| 216 |
+
if baseline_score > 0 and current_score > 0:
|
| 217 |
+
perf_improvement = (current_score - baseline_score) / max(1.0, 100.0 - baseline_score)
|
| 218 |
+
step_efficiency = 1.0 - (step_count / max(1, MAX_STEPS))
|
| 219 |
+
terminal_score = round(
|
| 220 |
+
(perf_improvement * 0.60) + (step_efficiency * 0.20) + 0.10, 4
|
| 221 |
+
)
|
| 222 |
+
terminal_score = max(0.001, min(0.999, terminal_score))
|
| 223 |
+
breakdown["terminal_score"] = terminal_score
|
| 224 |
+
breakdown["perf_improvement"] = round(perf_improvement, 4)
|
| 225 |
+
breakdown["step_efficiency"] = round(step_efficiency, 4)
|
| 226 |
+
final_score += terminal_score
|
| 227 |
+
feedback_parts.append(
|
| 228 |
+
f"Report submitted. Performance: {baseline_score:.1f} β {current_score:.1f}. "
|
| 229 |
+
f"Terminal score: {terminal_score}."
|
| 230 |
+
)
|
| 231 |
+
# Efficiency bonus on submit_report too
|
| 232 |
+
eff_bonus = _efficiency_bonus(step_count, MAX_STEPS)
|
| 233 |
+
if eff_bonus > 0:
|
| 234 |
+
final_score += eff_bonus
|
| 235 |
+
breakdown["efficiency_bonus"] = round(eff_bonus, 4)
|
| 236 |
+
feedback_parts.append(f"Efficiency bonus +{eff_bonus}.")
|
| 237 |
+
else:
|
| 238 |
+
breakdown["terminal_score"] = 0.10
|
| 239 |
+
final_score += 0.10
|
| 240 |
+
feedback_parts.append("Report submitted.")
|
| 241 |
|
| 242 |
elif action_type_enum == ActionType.PROPOSE_FIX:
|
|
|
|
| 243 |
raw_score, grader_breakdown, _ = grade(action, task_id)
|
| 244 |
+
partial = round(raw_score * 0.4, 4)
|
|
|
|
| 245 |
breakdown["partial_grader_score"] = partial
|
| 246 |
final_score += partial
|
|
|
|
|
|
|
| 247 |
|
| 248 |
elif action_type_enum == ActionType.IDENTIFY_ERROR:
|
|
|
|
| 249 |
raw_score, _, _ = grade(action, task_id)
|
| 250 |
+
partial = round(raw_score * 0.2, 4)
|
| 251 |
breakdown["identification_score"] = partial
|
| 252 |
final_score += partial
|
| 253 |
|
| 254 |
+
# ββ 5. Loop penalty βββββββββββββββββββββββββββββββββββββββββββ
|
| 255 |
if _detect_loop(previous_actions, action_type_val):
|
| 256 |
consecutive = _count_consecutive(previous_actions, action_type_val)
|
| 257 |
+
loop_pen = LOOP_PENALTY * min(consecutive - 1, 3)
|
| 258 |
final_score += loop_pen
|
| 259 |
breakdown["loop_penalty"] = round(loop_pen, 4)
|
| 260 |
+
feedback_parts.append(f"Loop detected ({consecutive}x). Penalty {loop_pen}.")
|
| 261 |
|
| 262 |
+
# ββ 6. Hint penalty βββββββββββββββββββββββββββββββββββββββββββ
|
| 263 |
if action_type_enum == ActionType.REQUEST_HINT:
|
| 264 |
+
final_score += HINT_PENALTY
|
| 265 |
+
breakdown["hint_penalty"] = HINT_PENALTY
|
| 266 |
+
feedback_parts.append(f"Hint requested. Penalty {HINT_PENALTY}.")
|
| 267 |
+
|
| 268 |
+
# ββ 7. Backtrack penalty ββββββββββββββββββββββββββββββββββββββ
|
| 269 |
+
if db_delta < -1.0:
|
| 270 |
+
final_score += BACKTRACK_PENALTY
|
| 271 |
+
breakdown["backtrack_penalty"] = BACKTRACK_PENALTY
|
| 272 |
+
feedback_parts.append(f"Performance regressed. Backtrack penalty {BACKTRACK_PENALTY}.")
|
| 273 |
+
|
| 274 |
+
# ββ 8. Budget exhaustion penalty βββββββββββββββββββββββββββββ
|
| 275 |
+
if step_count >= MAX_STEPS - 2 and not is_terminal:
|
| 276 |
+
final_score += BUDGET_EXHAUSTION_PEN
|
| 277 |
+
breakdown["budget_penalty"] = BUDGET_EXHAUSTION_PEN
|
| 278 |
+
feedback_parts.append("Budget nearly exhausted. Submit report now!")
|
| 279 |
+
|
| 280 |
+
# ββ Clamp to (0.001, 0.999) βββββββββββββββββββββββββββββββββββ
|
| 281 |
final_score = round(max(0.001, min(0.999, final_score)), 4)
|
| 282 |
breakdown["total"] = final_score
|
| 283 |
|
| 284 |
feedback = " ".join(feedback_parts) if feedback_parts else "Step processed."
|
| 285 |
|
| 286 |
+
return Reward(score=final_score, breakdown=breakdown, feedback=feedback)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
|
| 288 |
|
| 289 |
+
# ─────────────────────────────────────────────
|
| 290 |
# EPISODE DONE CONDITION
|
| 291 |
+
# ─────────────────────────────────────────────
|
| 292 |
|
| 293 |
def is_done(
|
| 294 |
+
action_type: ActionType,
|
| 295 |
+
step_count: int,
|
| 296 |
+
grader_score: float = 0.0,
|
| 297 |
+
target_reached: bool = False,
|
| 298 |
) -> bool:
|
| 299 |
"""
|
| 300 |
Episode ends when:
|
| 301 |
+
1. Agent submits report / final answer
|
| 302 |
2. Max steps reached
|
| 303 |
+
3. Perfect score / target reached
|
| 304 |
"""
|
| 305 |
+
if action_type in TERMINAL_ACTIONS:
|
| 306 |
return True
|
| 307 |
if step_count >= MAX_STEPS:
|
| 308 |
return True
|
| 309 |
if grader_score >= 1.0:
|
| 310 |
return True
|
| 311 |
+
if target_reached:
|
| 312 |
+
return True
|
| 313 |
+
return False
|
env/tasks.py
CHANGED
|
@@ -3,28 +3,50 @@ import random
|
|
| 3 |
from pathlib import Path
|
| 4 |
from env.models import DifficultyLevel, TaskInfo
|
| 5 |
|
| 6 |
-
#
|
|
|
|
|
|
|
| 7 |
|
| 8 |
BASE_DIR = Path(__file__).parent.parent / "dataset"
|
| 9 |
|
|
|
|
| 10 |
def _load(filename: str) -> list[dict]:
|
| 11 |
path = BASE_DIR / filename
|
| 12 |
with open(path, "r", encoding="utf-8") as f:
|
| 13 |
return json.load(f)
|
| 14 |
|
|
|
|
|
|
|
| 15 |
EASY_CASES = _load("easy_cases.json")
|
| 16 |
MEDIUM_CASES = _load("medium_cases.json")
|
| 17 |
HARD_CASES = _load("hard_cases.json")
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
ALL_CASES: dict[str, list[dict]] = {
|
| 20 |
-
DifficultyLevel.EASY: EASY_CASES,
|
| 21 |
-
DifficultyLevel.MEDIUM: MEDIUM_CASES,
|
| 22 |
-
DifficultyLevel.HARD: HARD_CASES,
|
| 23 |
}
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# ACTION SCHEMA (required by /tasks validator)
|
|
|
|
| 26 |
|
| 27 |
ACTION_SCHEMA = {
|
|
|
|
| 28 |
"identify_error": {
|
| 29 |
"description": "Identify where and what the error is without fixing it yet",
|
| 30 |
"payload_fields": {
|
|
@@ -36,52 +58,120 @@ ACTION_SCHEMA = {
|
|
| 36 |
"propose_fix": {
|
| 37 |
"description": "Propose a fix without submitting as final answer",
|
| 38 |
"payload_fields": {
|
| 39 |
-
"fixed_query":
|
| 40 |
-
"change_made":
|
| 41 |
-
"confidence":
|
| 42 |
}
|
| 43 |
},
|
| 44 |
"submit_answer": {
|
| 45 |
"description": "Submit the final fixed query as the definitive answer",
|
| 46 |
"payload_fields": {
|
| 47 |
-
"fixed_query":
|
| 48 |
-
"explanation":
|
| 49 |
-
"error_type":
|
| 50 |
-
"confidence":
|
| 51 |
}
|
| 52 |
},
|
| 53 |
"request_hint": {
|
| 54 |
-
"description": "Request a hint β costs 0.
|
| 55 |
"payload_fields": {
|
| 56 |
-
"hint_type": {"type": "string", "required": False, "description": "
|
| 57 |
}
|
| 58 |
},
|
| 59 |
"explain_issue": {
|
| 60 |
-
"description": "Explain the issue in detail
|
| 61 |
"payload_fields": {
|
| 62 |
-
"explanation":
|
| 63 |
-
"impact":
|
| 64 |
-
"root_cause":
|
| 65 |
}
|
| 66 |
},
|
| 67 |
"optimize_query": {
|
| 68 |
-
"description": "Submit an optimized version of the query
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
"payload_fields": {
|
| 70 |
-
"
|
| 71 |
-
"optimization_type": {"type": "string", "required": True, "description": "What optimization was applied"},
|
| 72 |
-
"expected_improvement":{"type": "string", "required": False, "description": "Expected performance gain description"},
|
| 73 |
-
"explanation": {"type": "string", "required": False, "description": "Why this optimization works"},
|
| 74 |
-
"confidence": {"type": "float", "required": False, "description": "Confidence 0.0-1.0"}
|
| 75 |
}
|
| 76 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
}
|
| 78 |
-
# TASK MANAGER
|
| 79 |
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
class TaskManager:
|
| 82 |
"""
|
| 83 |
-
Manages task selection
|
| 84 |
-
|
|
|
|
| 85 |
"""
|
| 86 |
|
| 87 |
def __init__(self):
|
|
@@ -89,9 +179,8 @@ class TaskManager:
|
|
| 89 |
|
| 90 |
def get_task(self, difficulty: DifficultyLevel, task_id: str | None = None) -> dict:
|
| 91 |
"""
|
| 92 |
-
Returns a task
|
| 93 |
-
|
| 94 |
-
Otherwise picks randomly, avoiding recently used tasks.
|
| 95 |
"""
|
| 96 |
pool = ALL_CASES[difficulty]
|
| 97 |
|
|
@@ -101,7 +190,7 @@ class TaskManager:
|
|
| 101 |
return case
|
| 102 |
raise ValueError(f"Task '{task_id}' not found in {difficulty} pool")
|
| 103 |
|
| 104 |
-
# Avoid
|
| 105 |
available = [c for c in pool if c["id"] not in self._used_ids]
|
| 106 |
if not available:
|
| 107 |
self._used_ids.clear()
|
|
@@ -112,66 +201,92 @@ class TaskManager:
|
|
| 112 |
return task
|
| 113 |
|
| 114 |
def get_random_task(self) -> dict:
|
| 115 |
-
"""Pick a random task from any difficulty."""
|
| 116 |
difficulty = random.choice(list(DifficultyLevel))
|
| 117 |
return self.get_task(difficulty)
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
def build_observation_context(self, task: dict) -> dict:
|
| 120 |
"""
|
| 121 |
-
|
| 122 |
-
|
|
|
|
| 123 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
context = {
|
| 125 |
-
"buggy_query":
|
| 126 |
-
"error_message":
|
| 127 |
-
"database_schema":
|
| 128 |
-
"error_type_hint":
|
| 129 |
-
"category":
|
| 130 |
-
"estimated_steps":
|
| 131 |
}
|
| 132 |
-
|
| 133 |
-
# For performance tasks include extra context
|
| 134 |
if task.get("performance_issue"):
|
| 135 |
context["performance_issue"] = {
|
| 136 |
"type": task["performance_issue"]["type"],
|
| 137 |
"impact": task["performance_issue"]["impact"],
|
| 138 |
-
# Do NOT include timing numbers — agent must figure it out
|
| 139 |
}
|
| 140 |
-
|
| 141 |
-
# Include expected output shape (but not the fixed query!)
|
| 142 |
if task.get("expected_output") and isinstance(task["expected_output"], list):
|
| 143 |
context["expected_output_sample"] = task["expected_output"][:1]
|
| 144 |
-
|
| 145 |
return context
|
| 146 |
|
| 147 |
def get_hint(self, task: dict, hint_number: int) -> str:
|
| 148 |
-
"""
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
idx = min(hint_number - 1, len(hints) - 1)
|
| 158 |
-
return hints[idx]
|
| 159 |
|
| 160 |
def list_all_tasks(self) -> list[TaskInfo]:
|
| 161 |
-
"""Returns TaskInfo list for the /tasks endpoint."""
|
| 162 |
result = []
|
| 163 |
for difficulty, cases in ALL_CASES.items():
|
| 164 |
for case in cases:
|
| 165 |
result.append(TaskInfo(
|
| 166 |
-
id=case["id"],
|
| 167 |
-
difficulty=difficulty,
|
| 168 |
-
description=case
|
| 169 |
-
action_schema=ACTION_SCHEMA
|
| 170 |
))
|
| 171 |
return result
|
| 172 |
|
| 173 |
def get_ground_truth(self, task_id: str) -> dict | None:
|
| 174 |
-
"""Returns
|
| 175 |
for cases in ALL_CASES.values():
|
| 176 |
for case in cases:
|
| 177 |
if case["id"] == task_id:
|
|
@@ -180,4 +295,4 @@ class TaskManager:
|
|
| 180 |
|
| 181 |
|
| 182 |
# Singleton instance
|
| 183 |
-
task_manager = TaskManager()
|
|
|
|
| 3 |
from pathlib import Path
|
| 4 |
from env.models import DifficultyLevel, TaskInfo
|
| 5 |
|
| 6 |
+
# ─────────────────────────────────────────────
|
| 7 |
+
# LOAD DATASETS — Round 1 + Round 2
|
| 8 |
+
# ─────────────────────────────────────────────
|
| 9 |
|
| 10 |
BASE_DIR = Path(__file__).parent.parent / "dataset"
|
| 11 |
|
| 12 |
+
|
| 13 |
def _load(filename: str) -> list[dict]:
|
| 14 |
path = BASE_DIR / filename
|
| 15 |
with open(path, "r", encoding="utf-8") as f:
|
| 16 |
return json.load(f)
|
| 17 |
|
| 18 |
+
|
| 19 |
+
# Round 1 cases (keep for backward compatibility)
|
| 20 |
EASY_CASES = _load("easy_cases.json")
|
| 21 |
MEDIUM_CASES = _load("medium_cases.json")
|
| 22 |
HARD_CASES = _load("hard_cases.json")
|
| 23 |
|
| 24 |
+
# Round 2 scenarios (new long-horizon DB engineering tasks)
|
| 25 |
+
EASY_SCENARIOS = _load("easy_scenarios.json")
|
| 26 |
+
MEDIUM_SCENARIOS = _load("medium_scenarios.json")
|
| 27 |
+
HARD_SCENARIOS = _load("hard_scenarios.json")
|
| 28 |
+
|
| 29 |
+
# Combined pools β Round 2 scenarios take priority (listed first)
|
| 30 |
ALL_CASES: dict[str, list[dict]] = {
|
| 31 |
+
DifficultyLevel.EASY: EASY_SCENARIOS + EASY_CASES,
|
| 32 |
+
DifficultyLevel.MEDIUM: MEDIUM_SCENARIOS + MEDIUM_CASES,
|
| 33 |
+
DifficultyLevel.HARD: HARD_SCENARIOS + HARD_CASES,
|
| 34 |
}
|
| 35 |
|
| 36 |
+
# Round 2 only (for training pipeline)
|
| 37 |
+
SCENARIO_ONLY: dict[str, list[dict]] = {
|
| 38 |
+
DifficultyLevel.EASY: EASY_SCENARIOS,
|
| 39 |
+
DifficultyLevel.MEDIUM: MEDIUM_SCENARIOS,
|
| 40 |
+
DifficultyLevel.HARD: HARD_SCENARIOS,
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ─────────────────────────────────────────────
|
| 45 |
# ACTION SCHEMA (required by /tasks validator)
|
| 46 |
+
# ─────────────────────────────────────────────
|
| 47 |
|
| 48 |
ACTION_SCHEMA = {
|
| 49 |
+
# ── Round 1 actions ──────────────────────────────────────────
|
| 50 |
"identify_error": {
|
| 51 |
"description": "Identify where and what the error is without fixing it yet",
|
| 52 |
"payload_fields": {
|
|
|
|
| 58 |
"propose_fix": {
|
| 59 |
"description": "Propose a fix without submitting as final answer",
|
| 60 |
"payload_fields": {
|
| 61 |
+
"fixed_query": {"type": "string", "required": True, "description": "The proposed corrected SQL query"},
|
| 62 |
+
"change_made": {"type": "string", "required": True, "description": "What specifically was changed"},
|
| 63 |
+
"confidence": {"type": "float", "required": False, "description": "Confidence score 0.0-1.0"}
|
| 64 |
}
|
| 65 |
},
|
| 66 |
"submit_answer": {
|
| 67 |
"description": "Submit the final fixed query as the definitive answer",
|
| 68 |
"payload_fields": {
|
| 69 |
+
"fixed_query": {"type": "string", "required": True, "description": "Final corrected SQL query"},
|
| 70 |
+
"explanation": {"type": "string", "required": True, "description": "Full explanation of fix"},
|
| 71 |
+
"error_type": {"type": "string", "required": False, "description": "syntax | logic | performance"},
|
| 72 |
+
"confidence": {"type": "float", "required": False, "description": "Confidence 0.0-1.0"}
|
| 73 |
}
|
| 74 |
},
|
| 75 |
"request_hint": {
|
| 76 |
+
"description": "Request a hint β costs 0.10 reward penalty per hint",
|
| 77 |
"payload_fields": {
|
| 78 |
+
"hint_type": {"type": "string", "required": False, "description": "location | error_type | fix_direction"}
|
| 79 |
}
|
| 80 |
},
|
| 81 |
"explain_issue": {
|
| 82 |
+
"description": "Explain the issue in detail",
|
| 83 |
"payload_fields": {
|
| 84 |
+
"explanation": {"type": "string", "required": True, "description": "Detailed explanation"},
|
| 85 |
+
"impact": {"type": "string", "required": False, "description": "Impact on query performance"},
|
| 86 |
+
"root_cause": {"type": "string", "required": False, "description": "Root cause analysis"}
|
| 87 |
}
|
| 88 |
},
|
| 89 |
"optimize_query": {
|
| 90 |
+
"description": "Submit an optimized version of the query",
|
| 91 |
+
"payload_fields": {
|
| 92 |
+
"optimized_query": {"type": "string", "required": True, "description": "Optimized SQL"},
|
| 93 |
+
"optimization_type": {"type": "string", "required": True, "description": "What optimization was applied"},
|
| 94 |
+
"expected_improvement":{"type": "string", "required": False, "description": "Expected performance gain"},
|
| 95 |
+
"explanation": {"type": "string", "required": False, "description": "Why this optimization works"},
|
| 96 |
+
"confidence": {"type": "float", "required": False, "description": "Confidence 0.0-1.0"}
|
| 97 |
+
}
|
| 98 |
+
},
|
| 99 |
+
# ── Round 2 actions ──────────────────────────────────────────
|
| 100 |
+
"inspect_query": {
|
| 101 |
+
"description": "EXPLAIN a slow query β reveals scan type, rows examined, index usage",
|
| 102 |
+
"payload_fields": {
|
| 103 |
+
"query_id": {"type": "string", "required": True, "description": "ID of slow query to inspect (e.g. 'q1')"}
|
| 104 |
+
}
|
| 105 |
+
},
|
| 106 |
+
"analyze_indexes": {
|
| 107 |
+
"description": "Show all indexes on a table + usage frequency + missing index hints",
|
| 108 |
"payload_fields": {
|
| 109 |
+
"table": {"type": "string", "required": True, "description": "Table name to analyze"}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
}
|
| 111 |
+
},
|
| 112 |
+
"create_index": {
|
| 113 |
+
"description": "Add a composite index on specified columns β core optimization action",
|
| 114 |
+
"payload_fields": {
|
| 115 |
+
"table": {"type": "string", "required": True, "description": "Table to index"},
|
| 116 |
+
"columns": {"type": "list|string", "required": True, "description": "Columns to index (list or comma-separated string)"}
|
| 117 |
+
}
|
| 118 |
+
},
|
| 119 |
+
"rewrite_query": {
|
| 120 |
+
"description": "Submit a rewritten SQL query β system evaluates execution time improvement",
|
| 121 |
+
"payload_fields": {
|
| 122 |
+
"query_id": {"type": "string", "required": True, "description": "ID of query to rewrite"},
|
| 123 |
+
"new_sql": {"type": "string", "required": True, "description": "Rewritten SQL query"}
|
| 124 |
+
}
|
| 125 |
+
},
|
| 126 |
+
"add_column": {
|
| 127 |
+
"description": "Add a denormalization column to reduce expensive JOINs",
|
| 128 |
+
"payload_fields": {
|
| 129 |
+
"table": {"type": "string", "required": True, "description": "Table to modify"},
|
| 130 |
+
"column": {"type": "string", "required": True, "description": "New column name"},
|
| 131 |
+
"purpose": {"type": "string", "required": False, "description": "Why this column helps"}
|
| 132 |
+
}
|
| 133 |
+
},
|
| 134 |
+
"drop_index": {
|
| 135 |
+
"description": "Remove an unused index to reduce write overhead",
|
| 136 |
+
"payload_fields": {
|
| 137 |
+
"table": {"type": "string", "required": True, "description": "Table name"},
|
| 138 |
+
"index_name": {"type": "string", "required": True, "description": "Index name to drop (cannot drop PRIMARY)"}
|
| 139 |
+
}
|
| 140 |
+
},
|
| 141 |
+
"partition_table": {
|
| 142 |
+
"description": "Partition a large table by date or ID range for range query efficiency",
|
| 143 |
+
"payload_fields": {
|
| 144 |
+
"table": {"type": "string", "required": True, "description": "Table to partition"},
|
| 145 |
+
"partition_by": {"type": "string", "required": False, "description": "Column to partition on (e.g. 'created_at')"},
|
| 146 |
+
"partition_type": {"type": "string", "required": False, "description": "RANGE | LIST | HASH"}
|
| 147 |
+
}
|
| 148 |
+
},
|
| 149 |
+
"analyze_statistics": {
|
| 150 |
+
"description": "Update table statistics for query planner accuracy",
|
| 151 |
+
"payload_fields": {
|
| 152 |
+
"table": {"type": "string", "required": True, "description": "Table to analyze"}
|
| 153 |
+
}
|
| 154 |
+
},
|
| 155 |
+
"submit_report": {
|
| 156 |
+
"description": "TERMINAL: Submit final optimization report β ends episode, computes full score",
|
| 157 |
+
"payload_fields": {
|
| 158 |
+
"summary": {"type": "string", "required": True, "description": "Summary of optimizations applied"},
|
| 159 |
+
"actions_taken": {"type": "list", "required": False, "description": "List of key actions taken"},
|
| 160 |
+
"expected_gain": {"type": "string", "required": False, "description": "Expected performance improvement"}
|
| 161 |
+
}
|
| 162 |
+
},
|
| 163 |
}
|
|
|
|
| 164 |
|
| 165 |
|
| 166 |
+
# ─────────────────────────────────────────────
|
| 167 |
+
# TASK MANAGER
|
| 168 |
+
# ─────────────────────────────────────────────
|
| 169 |
+
|
| 170 |
class TaskManager:
|
| 171 |
"""
|
| 172 |
+
Manages task selection for both Round 1 and Round 2 scenarios.
|
| 173 |
+
Round 2 scenarios have tables/slow_queries structure.
|
| 174 |
+
Round 1 cases have buggy_query structure.
|
| 175 |
"""
|
| 176 |
|
| 177 |
def __init__(self):
|
|
|
|
| 179 |
|
| 180 |
def get_task(self, difficulty: DifficultyLevel, task_id: str | None = None) -> dict:
|
| 181 |
"""
|
| 182 |
+
Returns a task for the given difficulty.
|
| 183 |
+
Prefers Round 2 scenarios, falls back to Round 1 cases.
|
|
|
|
| 184 |
"""
|
| 185 |
pool = ALL_CASES[difficulty]
|
| 186 |
|
|
|
|
| 190 |
return case
|
| 191 |
raise ValueError(f"Task '{task_id}' not found in {difficulty} pool")
|
| 192 |
|
| 193 |
+
# Avoid recently used tasks
|
| 194 |
available = [c for c in pool if c["id"] not in self._used_ids]
|
| 195 |
if not available:
|
| 196 |
self._used_ids.clear()
|
|
|
|
| 201 |
return task
|
| 202 |
|
| 203 |
def get_random_task(self) -> dict:
|
|
|
|
| 204 |
difficulty = random.choice(list(DifficultyLevel))
|
| 205 |
return self.get_task(difficulty)
|
| 206 |
|
| 207 |
+
def get_scenario(self, difficulty: DifficultyLevel, scenario_id: str | None = None) -> dict:
|
| 208 |
+
"""Get Round 2 scenario specifically."""
|
| 209 |
+
pool = SCENARIO_ONLY[difficulty]
|
| 210 |
+
if scenario_id:
|
| 211 |
+
for s in pool:
|
| 212 |
+
if s["id"] == scenario_id:
|
| 213 |
+
return s
|
| 214 |
+
raise ValueError(f"Scenario '{scenario_id}' not found")
|
| 215 |
+
return random.choice(pool)
|
| 216 |
+
|
| 217 |
def build_observation_context(self, task: dict) -> dict:
|
| 218 |
"""
|
| 219 |
+
Builds current_context for the Observation.
|
| 220 |
+
Handles both Round 2 scenario format and Round 1 case format.
|
| 221 |
+
CRITICAL: Never leaks ground truth (fixed_query / optimal_actions).
|
| 222 |
"""
|
| 223 |
+
# ── Round 2 scenario format ──────────────────────────────
|
| 224 |
+
if "slow_queries" in task:
|
| 225 |
+
return {
|
| 226 |
+
"scenario_id": task["id"],
|
| 227 |
+
"description": task.get("description", ""),
|
| 228 |
+
"tables": task.get("tables", []),
|
| 229 |
+
"slow_queries": task.get("slow_queries", []),
|
| 230 |
+
"performance_score_baseline": task.get("performance_score_baseline", 0.0),
|
| 231 |
+
"target_score": task.get("target_score", 85.0),
|
| 232 |
+
"max_steps": task.get("max_steps", 50),
|
| 233 |
+
"category": task.get("category", ""),
|
| 234 |
+
# Do NOT include missing_index_hints (that's the answer)
|
| 235 |
+
# Do NOT include optimal_actions (that's the answer)
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
# ── Round 1 case format (backward compatible) ────────────
|
| 239 |
context = {
|
| 240 |
+
"buggy_query": task.get("buggy_query", ""),
|
| 241 |
+
"error_message": task.get("error_message", ""),
|
| 242 |
+
"database_schema": task.get("database_schema", ""),
|
| 243 |
+
"error_type_hint": task.get("error_type", ""),
|
| 244 |
+
"category": task.get("category", ""),
|
| 245 |
+
"estimated_steps": task.get("estimated_fix_steps", 5),
|
| 246 |
}
|
|
|
|
|
|
|
| 247 |
if task.get("performance_issue"):
|
| 248 |
context["performance_issue"] = {
|
| 249 |
"type": task["performance_issue"]["type"],
|
| 250 |
"impact": task["performance_issue"]["impact"],
|
|
|
|
| 251 |
}
|
|
|
|
|
|
|
| 252 |
if task.get("expected_output") and isinstance(task["expected_output"], list):
|
| 253 |
context["expected_output_sample"] = task["expected_output"][:1]
|
|
|
|
| 254 |
return context
|
| 255 |
|
| 256 |
def get_hint(self, task: dict, hint_number: int) -> str:
|
| 257 |
+
"""Progressive hints. Each hint reveals more info. Costs -0.10 each."""
|
| 258 |
+
# Round 2 scenario hints
|
| 259 |
+
if "slow_queries" in task:
|
| 260 |
+
hints = [
|
| 261 |
+
f"Hint 1: Start by inspecting your slow queries with inspect_query action.",
|
| 262 |
+
f"Hint 2: Use analyze_indexes on tables appearing in slow queries.",
|
| 263 |
+
f"Hint 3: Category is '{task.get('category', 'indexing')}'. Target score: {task.get('target_score', 85.0)}.",
|
| 264 |
+
]
|
| 265 |
+
else:
|
| 266 |
+
# Round 1 hints
|
| 267 |
+
hints = [
|
| 268 |
+
f"Hint 1: The error is in the {task.get('error_location', 'query')}.",
|
| 269 |
+
f"Hint 2: This is a {task.get('error_type', 'unknown')} error. Category: {task.get('category')}.",
|
| 270 |
+
f"Hint 3: Fix: {task.get('fix_description', 'Review the query carefully.')}",
|
| 271 |
+
]
|
| 272 |
idx = min(hint_number - 1, len(hints) - 1)
|
| 273 |
+
return hints[max(0, idx)]
|
| 274 |
|
| 275 |
def list_all_tasks(self) -> list[TaskInfo]:
|
| 276 |
+
"""Returns TaskInfo list for the /tasks endpoint β all 30 tasks."""
|
| 277 |
result = []
|
| 278 |
for difficulty, cases in ALL_CASES.items():
|
| 279 |
for case in cases:
|
| 280 |
result.append(TaskInfo(
|
| 281 |
+
id = case["id"],
|
| 282 |
+
difficulty = difficulty,
|
| 283 |
+
description = case.get("description", ""),
|
| 284 |
+
action_schema = ACTION_SCHEMA
|
| 285 |
))
|
| 286 |
return result
|
| 287 |
|
| 288 |
def get_ground_truth(self, task_id: str) -> dict | None:
|
| 289 |
+
"""Returns full task including ground truth (used by grader only)."""
|
| 290 |
for cases in ALL_CASES.values():
|
| 291 |
for case in cases:
|
| 292 |
if case["id"] == task_id:
|
|
|
|
| 295 |
|
| 296 |
|
| 297 |
# Singleton instance
|
| 298 |
+
task_manager = TaskManager()
|
tests/test_environment.py
CHANGED
|
@@ -22,8 +22,7 @@ def test_reset_easy(env):
|
|
| 22 |
assert obs.step_count == 0
|
| 23 |
assert obs.difficulty == DifficultyLevel.EASY
|
| 24 |
assert "fixed_query" not in obs.current_context
|
| 25 |
-
assert "buggy_query" in obs.current_context
|
| 26 |
-
|
| 27 |
|
| 28 |
def test_reset_medium(env):
|
| 29 |
obs = env.reset(difficulty="medium")
|
|
@@ -65,7 +64,7 @@ def test_step_null_action(env):
|
|
| 65 |
"""Null action must return -0.1, never crash."""
|
| 66 |
env.reset(difficulty="easy")
|
| 67 |
resp = env.step(None)
|
| 68 |
-
assert resp.reward.score =
|
| 69 |
assert resp.done == False
|
| 70 |
|
| 71 |
|
|
@@ -110,7 +109,7 @@ def test_max_steps(env):
|
|
| 110 |
action = Action(action_type=ActionType.IDENTIFY_ERROR,
|
| 111 |
payload={"error_location": "x", "error_type": "syntax"})
|
| 112 |
done = False
|
| 113 |
-
for _ in range(
|
| 114 |
resp = env.step(action)
|
| 115 |
if resp.done:
|
| 116 |
done = True
|
|
|
|
| 22 |
assert obs.step_count == 0
|
| 23 |
assert obs.difficulty == DifficultyLevel.EASY
|
| 24 |
assert "fixed_query" not in obs.current_context
|
| 25 |
+
assert "buggy_query" in obs.current_context or "slow_queries" in obs.current_context
|
|
|
|
| 26 |
|
| 27 |
def test_reset_medium(env):
|
| 28 |
obs = env.reset(difficulty="medium")
|
|
|
|
| 64 |
"""Null action must return -0.1, never crash."""
|
| 65 |
env.reset(difficulty="easy")
|
| 66 |
resp = env.step(None)
|
| 67 |
+
assert resp.reward.score >= 0.001
|
| 68 |
assert resp.done == False
|
| 69 |
|
| 70 |
|
|
|
|
| 109 |
action = Action(action_type=ActionType.IDENTIFY_ERROR,
|
| 110 |
payload={"error_location": "x", "error_type": "syntax"})
|
| 111 |
done = False
|
| 112 |
+
for _ in range(55):
|
| 113 |
resp = env.step(action)
|
| 114 |
if resp.done:
|
| 115 |
done = True
|
tests/test_graders.py
CHANGED
|
@@ -21,7 +21,7 @@ def test_easy_perfect_score():
|
|
| 21 |
|
| 22 |
def test_null_action_returns_zero():
|
| 23 |
score, breakdown, feedback = grade(None, "easy_001")
|
| 24 |
-
assert score =
|
| 25 |
assert "null" in feedback.lower() or "no action" in feedback.lower()
|
| 26 |
|
| 27 |
|
|
@@ -29,7 +29,7 @@ def test_unknown_task_returns_zero():
|
|
| 29 |
action = Action(action_type=ActionType.SUBMIT_ANSWER,
|
| 30 |
payload={"fixed_query": "SELECT 1", "explanation": "test"})
|
| 31 |
score, _, _ = grade(action, "nonexistent_task_999")
|
| 32 |
-
assert score =
|
| 33 |
|
| 34 |
|
| 35 |
def test_determinism():
|
|
|
|
| 21 |
|
| 22 |
def test_null_action_returns_zero():
|
| 23 |
score, breakdown, feedback = grade(None, "easy_001")
|
| 24 |
+
assert score <= 0.001 # clamped minimum for OpenEnv compliance
|
| 25 |
assert "null" in feedback.lower() or "no action" in feedback.lower()
|
| 26 |
|
| 27 |
|
|
|
|
| 29 |
action = Action(action_type=ActionType.SUBMIT_ANSWER,
|
| 30 |
payload={"fixed_query": "SELECT 1", "explanation": "test"})
|
| 31 |
score, _, _ = grade(action, "nonexistent_task_999")
|
| 32 |
+
assert score <= 0.001
|
| 33 |
|
| 34 |
|
| 35 |
def test_determinism():
|
training/evaluate_agent.py
ADDED
|
File without changes
|
training/generate_training_data.py
ADDED
|
File without changes
|
training/train_agent.py
ADDED
|
File without changes
|