junaid0600 committed on
Commit
8cb206e
·
1 Parent(s): 1eef47f

Round 2: SQL Database Engineer Agent - 24/24 tests passing

Browse files
api/server.py CHANGED
@@ -16,7 +16,7 @@ from env.models import (
16
  StepResponse, ResetResponse, TaskListResponse,
17
  BaselineResponse, BaselineResult,
18
  GraderRequest, GraderResponse,
19
- HealthResponse, TaskInfo
20
  )
21
  from env.tasks import task_manager, ACTION_SCHEMA
22
  from env.graders import grade
@@ -33,18 +33,21 @@ async def lifespan(app: FastAPI):
33
  environment.reset(difficulty="easy")
34
  yield
35
 
 
36
  # ─────────────────────────────────────────────
37
  # APP DEFINITION
38
  # ─────────────────────────────────────────────
39
 
40
  app = FastAPI(
41
- title = "SQL Query Debugger β€” OpenEnv Environment",
42
  description = (
43
  "An OpenEnv-compliant reinforcement learning environment where AI agents "
44
- "learn to debug SQL queries across syntax errors, logic bugs, and performance issues. "
45
- "Built for the META x PyTorch x SST OpenEnv Hackathon."
 
 
46
  ),
47
- version = "1.0.0",
48
  lifespan = lifespan,
49
  docs_url = "/docs",
50
  redoc_url = "/redoc",
@@ -72,12 +75,11 @@ async def global_exception_handler(request: Request, exc: Exception):
72
 
73
 
74
  # ─────────────────────────────────────────────
75
- # FAVICON — fix 404
76
  # ─────────────────────────────────────────────
77
 
78
  @app.get("/favicon.ico", include_in_schema=False)
79
  async def favicon():
80
- """Returns 204 No Content instead of 404 for favicon requests."""
81
  return Response(status_code=204)
82
 
83
 
@@ -87,10 +89,10 @@ async def favicon():
87
 
88
  @app.get("/health", response_model=HealthResponse, tags=["System"])
89
  async def health():
90
- """Liveness check. Always returns 200. Used by HF Space health monitoring."""
91
  return HealthResponse(
92
  status = "ok",
93
- version = "1.0.0",
94
  uptime = round(time.time() - _startup_time, 2)
95
  )
96
 
@@ -106,8 +108,8 @@ class ResetBody(BaseModel):
106
  @app.post("/reset", response_model=Observation, tags=["Environment"])
107
  async def reset(body: ResetBody = ResetBody()):
108
  """
109
- Starts a fresh episode. Returns the initial Observation the agent sees.
110
- Edge case: always returns valid Observation even if dataset issues occur.
111
  """
112
  try:
113
  obs = environment.reset(
@@ -129,8 +131,9 @@ async def reset(body: ResetBody = ResetBody()):
129
  async def step(action: Action):
130
  """
131
  Submits an action to the environment.
132
- Returns (observation, reward, done, info).
133
- Edge cases: null action, malformed payload, episode already done.
 
134
  """
135
  try:
136
  response = environment.step(action)
@@ -140,8 +143,8 @@ async def step(action: Action):
140
  return StepResponse(
141
  observation = environment._build_observation(),
142
  reward = Reward(
143
- score = -0.1,
144
- breakdown = {"validation_error": -0.1},
145
  feedback = f"Malformed action: {str(e)}"
146
  ),
147
  done = False,
@@ -157,11 +160,7 @@ async def step(action: Action):
157
 
158
  @app.get("/state", response_model=EpisodeState, tags=["Environment"])
159
  async def state():
160
- """
161
- Returns full current environment state.
162
- Works before reset() is called — returns default empty state.
163
- Always JSON-serializable. Never crashes.
164
- """
165
  return environment.state()
166
 
167
 
@@ -172,8 +171,8 @@ async def state():
172
  @app.get("/tasks", response_model=TaskListResponse, tags=["Tasks"])
173
  async def tasks():
174
  """
175
- Lists all 15 tasks with full action schema definitions.
176
- Validator checks for action field definitions, not just task names.
177
  """
178
  all_tasks = task_manager.list_all_tasks()
179
  return TaskListResponse(
@@ -191,8 +190,8 @@ async def tasks():
191
  async def grader(request: GraderRequest):
192
  """
193
  Grades a completed episode action.
 
194
  Returns float score strictly between 0.0 and 1.0 exclusive.
195
- Never crashes.
196
  """
197
  try:
198
  if request.action is None:
@@ -201,14 +200,42 @@ async def grader(request: GraderRequest):
201
  feedback = "No action provided for grading.",
202
  breakdown = {"error": "null_action"}
203
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  score, breakdown, feedback = grade(request.action, request.task_id)
205
- # Clamp strictly between 0 and 1 exclusive
206
  score = max(0.001, min(0.999, score))
207
- return GraderResponse(
208
- score = score,
209
- feedback = feedback,
210
- breakdown = breakdown
211
- )
212
  except Exception as e:
213
  return GraderResponse(
214
  score = 0.001,
@@ -216,6 +243,7 @@ async def grader(request: GraderRequest):
216
  breakdown = {"error": str(e)}
217
  )
218
 
 
219
  # ─────────────────────────────────────────────
220
  # 7. /baseline — POST
221
  # ─────────────────────────────────────────────
@@ -223,9 +251,8 @@ async def grader(request: GraderRequest):
223
  @app.post("/baseline", response_model=BaselineResponse, tags=["Baseline"])
224
  async def baseline():
225
  """
226
- Runs the baseline agent against all 3 difficulty levels.
227
- Returns scores JSON. Must complete within 60 seconds.
228
- Edge case: OPENAI_API_KEY not set → continues with rule-based agent.
229
  """
230
  try:
231
  import baseline as baseline_module
@@ -236,45 +263,72 @@ async def baseline():
236
  return results
237
  except asyncio.TimeoutError:
238
  return BaselineResponse(
239
- results=[
240
- BaselineResult(
241
- task_id = "timeout",
242
- difficulty = DifficultyLevel.EASY,
243
- score = 0.0,
244
- steps = 0,
245
- feedback = "Baseline timed out after 55 seconds."
246
- )
247
- ],
248
  average_score=0.0
249
  )
250
  except Exception as e:
251
  return BaselineResponse(
252
- results=[
253
- BaselineResult(
254
- task_id = "error",
255
- difficulty = DifficultyLevel.EASY,
256
- score = 0.0,
257
- steps = 0,
258
- feedback = f"Baseline error: {str(e)}"
259
- )
260
- ],
261
  average_score=0.0
262
  )
263
 
264
 
265
  # ─────────────────────────────────────────────
266
- # ROOT — project info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  # ─────────────────────────────────────────────
268
 
269
  @app.get("/", tags=["System"])
270
  async def root():
271
  return {
272
- "name": "SQL Query Debugger β€” OpenEnv Environment",
273
- "version": "1.0.0",
 
274
  "docs": "/docs",
275
  "health": "/health",
276
- "endpoints": ["/reset", "/step", "/state", "/tasks", "/grader", "/baseline", "/health"],
277
- "hackathon": "META x PyTorch x SST OpenEnv Hackathon",
278
- "domain": "SQL Query Debugging",
279
- "tasks_count": 15,
280
- }
 
 
 
16
  StepResponse, ResetResponse, TaskListResponse,
17
  BaselineResponse, BaselineResult,
18
  GraderRequest, GraderResponse,
19
+ HealthResponse, TaskInfo, ProgressResponse
20
  )
21
  from env.tasks import task_manager, ACTION_SCHEMA
22
  from env.graders import grade
 
33
  environment.reset(difficulty="easy")
34
  yield
35
 
36
+
37
  # ─────────────────────────────────────────────
38
  # APP DEFINITION
39
  # ─────────────────────────────────────────────
40
 
41
  app = FastAPI(
42
+ title = "SQL Database Engineer Agent β€” OpenEnv Environment",
43
  description = (
44
  "An OpenEnv-compliant reinforcement learning environment where AI agents "
45
+ "learn to act like senior database engineers. "
46
+ "The agent manages a simulated production database over 50+ steps: "
47
+ "inspecting slow queries, creating indexes, rewriting queries, partitioning tables. "
48
+ "Built for the META x PyTorch x SST OpenEnv Hackathon Finals β€” April 25-26, Bangalore."
49
  ),
50
+ version = "2.0.0",
51
  lifespan = lifespan,
52
  docs_url = "/docs",
53
  redoc_url = "/redoc",
 
75
 
76
 
77
  # ─────────────────────────────────────────────
78
+ # FAVICON
79
  # ─────────────────────────────────────────────
80
 
81
  @app.get("/favicon.ico", include_in_schema=False)
82
  async def favicon():
 
83
  return Response(status_code=204)
84
 
85
 
 
89
 
90
  @app.get("/health", response_model=HealthResponse, tags=["System"])
91
  async def health():
92
+ """Liveness check. Always returns 200."""
93
  return HealthResponse(
94
  status = "ok",
95
+ version = "2.0.0",
96
  uptime = round(time.time() - _startup_time, 2)
97
  )
98
 
 
108
  @app.post("/reset", response_model=Observation, tags=["Environment"])
109
  async def reset(body: ResetBody = ResetBody()):
110
  """
111
+ Starts a fresh episode. Initializes DatabaseSimulator.
112
+ Returns the initial Observation with DB state and slow queries.
113
  """
114
  try:
115
  obs = environment.reset(
 
131
  async def step(action: Action):
132
  """
133
  Submits an action to the environment.
134
+ Round 2 actions: inspect_query, create_index, rewrite_query,
135
+ partition_table, analyze_statistics, analyze_indexes, submit_report.
136
+ Returns (observation, reward, done, info) with DB performance delta.
137
  """
138
  try:
139
  response = environment.step(action)
 
143
  return StepResponse(
144
  observation = environment._build_observation(),
145
  reward = Reward(
146
+ score = 0.001,
147
+ breakdown = {"validation_error": 0.001},
148
  feedback = f"Malformed action: {str(e)}"
149
  ),
150
  done = False,
 
160
 
161
  @app.get("/state", response_model=EpisodeState, tags=["Environment"])
162
  async def state():
163
+ """Returns full current environment state including performance history."""
 
 
 
 
164
  return environment.state()
165
 
166
 
 
171
  @app.get("/tasks", response_model=TaskListResponse, tags=["Tasks"])
172
  async def tasks():
173
  """
174
+ Lists all 30 tasks (15 Round 2 scenarios + 15 Round 1 cases).
175
+ Includes complete action schema for all 15 action types.
176
  """
177
  all_tasks = task_manager.list_all_tasks()
178
  return TaskListResponse(
 
190
  async def grader(request: GraderRequest):
191
  """
192
  Grades a completed episode action.
193
+ For Round 2 submit_report: computes score from DB performance improvement.
194
  Returns float score strictly between 0.0 and 1.0 exclusive.
 
195
  """
196
  try:
197
  if request.action is None:
 
200
  feedback = "No action provided for grading.",
201
  breakdown = {"error": "null_action"}
202
  )
203
+
204
+ # Round 2: submit_report grading uses DB state
205
+ if request.action.action_type == ActionType.SUBMIT_REPORT:
206
+ ep_state = environment.state()
207
+ perf_history = ep_state.action_counts.get("_perf_history", [0.0])
208
+ baseline = ep_state.action_counts.get("_baseline_score", 0.0)
209
+ best_score = ep_state.action_counts.get("_best_score", 0.0)
210
+ current = perf_history[-1] if perf_history else 0.0
211
+ max_possible = max(1.0, 100.0 - baseline)
212
+
213
+ perf_improvement = (current - baseline) / max_possible
214
+ step_efficiency = 1.0 - (ep_state.step_count / max(1, 50))
215
+ score = round(
216
+ (perf_improvement * 0.60) + (step_efficiency * 0.20) + 0.10, 4
217
+ )
218
+ score = max(0.001, min(0.999, score))
219
+
220
+ return GraderResponse(
221
+ score = score,
222
+ feedback = (
223
+ f"DB performance: {baseline:.1f} β†’ {current:.1f} "
224
+ f"(best: {best_score:.1f}). "
225
+ f"Steps used: {ep_state.step_count}/50."
226
+ ),
227
+ breakdown = {
228
+ "perf_improvement": round(perf_improvement, 4),
229
+ "step_efficiency": round(step_efficiency, 4),
230
+ "base_score": 0.10,
231
+ }
232
+ )
233
+
234
+ # Round 1 grading
235
  score, breakdown, feedback = grade(request.action, request.task_id)
 
236
  score = max(0.001, min(0.999, score))
237
+ return GraderResponse(score=score, feedback=feedback, breakdown=breakdown)
238
+
 
 
 
239
  except Exception as e:
240
  return GraderResponse(
241
  score = 0.001,
 
243
  breakdown = {"error": str(e)}
244
  )
245
 
246
+
247
  # ─────────────────────────────────────────────
248
  # 7. /baseline — POST
249
  # ─────────────────────────────────────────────
 
251
  @app.post("/baseline", response_model=BaselineResponse, tags=["Baseline"])
252
  async def baseline():
253
  """
254
+ Runs the baseline agent against all difficulty levels.
255
+ Must complete within 60 seconds.
 
256
  """
257
  try:
258
  import baseline as baseline_module
 
263
  return results
264
  except asyncio.TimeoutError:
265
  return BaselineResponse(
266
+ results=[BaselineResult(
267
+ task_id="timeout", difficulty=DifficultyLevel.EASY,
268
+ score=0.0, steps=0, feedback="Baseline timed out."
269
+ )],
 
 
 
 
 
270
  average_score=0.0
271
  )
272
  except Exception as e:
273
  return BaselineResponse(
274
+ results=[BaselineResult(
275
+ task_id="error", difficulty=DifficultyLevel.EASY,
276
+ score=0.0, steps=0, feedback=f"Baseline error: {str(e)}"
277
+ )],
 
 
 
 
 
278
  average_score=0.0
279
  )
280
 
281
 
282
  # ─────────────────────────────────────────────
283
+ # 8. /progress — GET (Round 2 NEW)
284
+ # ─────────────────────────────────────────────
285
+
286
+ @app.get("/progress", response_model=ProgressResponse, tags=["Training"])
287
+ async def progress():
288
+ """
289
+ Returns DB performance history for training visualization.
290
+ Used by evaluate_agent.py to generate reward curves.
291
+ Shows improvement from baseline to current score.
292
+ """
293
+ ep_state = environment.state()
294
+ ac = ep_state.action_counts
295
+ perf_history = ac.get("_perf_history", [])
296
+ milestones = ac.get("_milestones", [])
297
+ baseline = ac.get("_baseline_score", 0.0)
298
+ target = ac.get("_target_score", 85.0)
299
+ best = ac.get("_best_score", 0.0)
300
+ current = perf_history[-1] if perf_history else 0.0
301
+
302
+ return ProgressResponse(
303
+ scenario_id = ep_state.task_id,
304
+ performance_score = current,
305
+ baseline_score = baseline,
306
+ target_score = target,
307
+ improvement_history = perf_history,
308
+ milestones_earned = milestones,
309
+ best_score = best,
310
+ steps_used = ep_state.step_count,
311
+ budget_remaining = max(0, 50 - ep_state.step_count),
312
+ total_reward = ep_state.total_reward,
313
+ )
314
+
315
+
316
+ # ─────────────────────────────────────────────
317
+ # ROOT
318
  # ─────────────────────────────────────────────
319
 
320
  @app.get("/", tags=["System"])
321
  async def root():
322
  return {
323
+ "name": "SQL Database Engineer Agent β€” OpenEnv Environment",
324
+ "version": "2.0.0",
325
+ "tagline": "Training LLMs to act like senior database engineers",
326
  "docs": "/docs",
327
  "health": "/health",
328
+ "endpoints": ["/reset", "/step", "/state", "/tasks", "/grader", "/baseline", "/progress", "/health"],
329
+ "hackathon": "META x PyTorch x SST OpenEnv Hackathon β€” Finals April 25-26 Bangalore",
330
+ "domain": "Long-Horizon Database Engineering",
331
+ "tasks_count": 30,
332
+ "max_steps": 50,
333
+ "themes": ["Long-Horizon Planning", "World Modeling", "Self-Improvement", "Wildcard"],
334
+ }
blog/mini_blog.md ADDED
File without changes
dataset/easy_scenarios.json ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "id": "easy_s001",
4
+ "description": "User lookup query taking 2s on 10K users table. Missing index on email column.",
5
+ "tables": [
6
+ {"name": "users", "rows": 10000, "indexes": ["PRIMARY"], "size_mb": 8}
7
+ ],
8
+ "slow_queries": [
9
+ {"id": "q1", "sql": "SELECT * FROM users WHERE email=?", "avg_ms": 2000, "main_table": "users", "rows_examined": 10000}
10
+ ],
11
+ "missing_index_hints": [
12
+ {"table": "users", "columns": ["email"], "reason": "email is used in WHERE clause but has no index"}
13
+ ],
14
+ "performance_score_baseline": 8.0,
15
+ "target_score": 80.0,
16
+ "max_steps": 15,
17
+ "optimal_actions": ["inspect_query:q1", "analyze_indexes:users", "create_index:users:email", "submit_report"],
18
+ "category": "indexing"
19
+ },
20
+ {
21
+ "id": "easy_s002",
22
+ "description": "Order status query scanning 50K orders. Composite index on user_id + status needed.",
23
+ "tables": [
24
+ {"name": "orders", "rows": 50000, "indexes": ["PRIMARY"], "size_mb": 120}
25
+ ],
26
+ "slow_queries": [
27
+ {"id": "q1", "sql": "SELECT * FROM orders WHERE user_id=? AND status=?", "avg_ms": 3500, "main_table": "orders", "rows_examined": 50000}
28
+ ],
29
+ "missing_index_hints": [
30
+ {"table": "orders", "columns": ["user_id", "status"], "reason": "Composite WHERE clause needs composite index"}
31
+ ],
32
+ "performance_score_baseline": 5.0,
33
+ "target_score": 85.0,
34
+ "max_steps": 15,
35
+ "optimal_actions": ["inspect_query:q1", "create_index:orders:user_id,status", "submit_report"],
36
+ "category": "indexing"
37
+ },
38
+ {
39
+ "id": "easy_s003",
40
+ "description": "Product search query doing full table scan on 20K products. Index on name column fixes it.",
41
+ "tables": [
42
+ {"name": "products", "rows": 20000, "indexes": ["PRIMARY"], "size_mb": 35}
43
+ ],
44
+ "slow_queries": [
45
+ {"id": "q1", "sql": "SELECT id, name, price FROM products WHERE name LIKE ?", "avg_ms": 1800, "main_table": "products", "rows_examined": 20000}
46
+ ],
47
+ "missing_index_hints": [
48
+ {"table": "products", "columns": ["name"], "reason": "LIKE queries benefit from index on name"}
49
+ ],
50
+ "performance_score_baseline": 10.0,
51
+ "target_score": 78.0,
52
+ "max_steps": 15,
53
+ "optimal_actions": ["inspect_query:q1", "create_index:products:name", "submit_report"],
54
+ "category": "indexing"
55
+ },
56
+ {
57
+ "id": "easy_s004",
58
+ "description": "Session lookup hitting 15K sessions table without index. Single index solves it.",
59
+ "tables": [
60
+ {"name": "sessions", "rows": 15000, "indexes": ["PRIMARY"], "size_mb": 12}
61
+ ],
62
+ "slow_queries": [
63
+ {"id": "q1", "sql": "SELECT * FROM sessions WHERE user_id=? AND expires_at > NOW()", "avg_ms": 1500, "main_table": "sessions", "rows_examined": 15000}
64
+ ],
65
+ "missing_index_hints": [
66
+ {"table": "sessions", "columns": ["user_id", "expires_at"], "reason": "Composite index on user_id + expires_at needed"}
67
+ ],
68
+ "performance_score_baseline": 12.0,
69
+ "target_score": 80.0,
70
+ "max_steps": 15,
71
+ "optimal_actions": ["inspect_query:q1", "create_index:sessions:user_id,expires_at", "submit_report"],
72
+ "category": "indexing"
73
+ },
74
+ {
75
+ "id": "easy_s005",
76
+ "description": "Log table growing to 30K entries. Query filtering by level and created_at is slow.",
77
+ "tables": [
78
+ {"name": "logs", "rows": 30000, "indexes": ["PRIMARY"], "size_mb": 50}
79
+ ],
80
+ "slow_queries": [
81
+ {"id": "q1", "sql": "SELECT * FROM logs WHERE level=? AND created_at > ?", "avg_ms": 2200, "main_table": "logs", "rows_examined": 30000}
82
+ ],
83
+ "missing_index_hints": [
84
+ {"table": "logs", "columns": ["level", "created_at"], "reason": "Compound filter needs compound index"}
85
+ ],
86
+ "performance_score_baseline": 7.8,
87
+ "target_score": 80.0,
88
+ "max_steps": 15,
89
+ "optimal_actions": ["inspect_query:q1", "create_index:logs:level,created_at", "submit_report"],
90
+ "category": "indexing"
91
+ }
92
+ ]
dataset/hard_scenarios.json ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "id": "hard_s001",
4
+ "description": "Financial DB: 500K transactions across 4 tables. 3 slow queries. Needs indexes, partition, and statistics.",
5
+ "tables": [
6
+ {"name": "transactions", "rows": 500000, "indexes": ["PRIMARY"], "size_mb": 2400},
7
+ {"name": "accounts", "rows": 50000, "indexes": ["PRIMARY"], "size_mb": 80},
8
+ {"name": "customers", "rows": 80000, "indexes": ["PRIMARY"], "size_mb": 120},
9
+ {"name": "audit_log", "rows": 1000000,"indexes": ["PRIMARY"], "size_mb": 5000}
10
+ ],
11
+ "slow_queries": [
12
+ {"id": "q1", "sql": "SELECT * FROM transactions WHERE account_id=? AND status=? AND created_at > ?", "avg_ms": 15000, "main_table": "transactions", "rows_examined": 500000},
13
+ {"id": "q2", "sql": "SELECT c.*, COUNT(t.id) FROM customers c, transactions t WHERE c.id = t.customer_id AND t.amount > ? GROUP BY c.id", "avg_ms": 22000, "main_table": "transactions", "rows_examined": 500000},
14
+ {"id": "q3", "sql": "SELECT * FROM audit_log WHERE entity_id=? AND entity_type=? ORDER BY created_at DESC LIMIT 100", "avg_ms": 18000, "main_table": "audit_log", "rows_examined": 1000000}
15
+ ],
16
+ "missing_index_hints": [
17
+ {"table": "transactions", "columns": ["account_id", "status", "created_at"], "reason": "Composite filter β€” high cardinality"},
18
+ {"table": "transactions", "columns": ["customer_id", "amount"], "reason": "JOIN + range filter"},
19
+ {"table": "audit_log", "columns": ["entity_id", "entity_type", "created_at"], "reason": "Lookup + ORDER BY on huge table"}
20
+ ],
21
+ "performance_score_baseline": 4.2,
22
+ "target_score": 70.0,
23
+ "max_steps": 50,
24
+ "optimal_actions": [
25
+ "inspect_query:q1", "inspect_query:q2", "inspect_query:q3",
26
+ "analyze_indexes:transactions", "analyze_indexes:audit_log",
27
+ "create_index:transactions:account_id,status,created_at",
28
+ "create_index:transactions:customer_id,amount",
29
+ "create_index:audit_log:entity_id,entity_type,created_at",
30
+ "rewrite_query:q2:SELECT c.id, c.name, COUNT(t.id) as tx_count FROM customers c INNER JOIN transactions t ON c.id = t.customer_id WHERE t.amount > ? GROUP BY c.id, c.name",
31
+ "partition_table:audit_log",
32
+ "analyze_statistics:transactions",
33
+ "analyze_statistics:audit_log",
34
+ "submit_report"
35
+ ],
36
+ "category": "financial"
37
+ },
38
+ {
39
+ "id": "hard_s002",
40
+ "description": "SaaS platform: 8-table schema, 200K+ records. Dashboard queries taking 20s+. Full optimization campaign.",
41
+ "tables": [
42
+ {"name": "workspaces", "rows": 5000, "indexes": ["PRIMARY"], "size_mb": 10},
43
+ {"name": "users", "rows": 80000, "indexes": ["PRIMARY"], "size_mb": 120},
44
+ {"name": "projects", "rows": 200000, "indexes": ["PRIMARY"], "size_mb": 450},
45
+ {"name": "tasks", "rows": 800000, "indexes": ["PRIMARY"], "size_mb": 3000},
46
+ {"name": "comments", "rows": 500000, "indexes": ["PRIMARY"], "size_mb": 1800},
47
+ {"name": "attachments", "rows": 300000, "indexes": ["PRIMARY"], "size_mb": 900},
48
+ {"name": "activity_log", "rows": 2000000,"indexes": ["PRIMARY"], "size_mb": 8000},
49
+ {"name": "notifications", "rows": 400000, "indexes": ["PRIMARY"], "size_mb": 600}
50
+ ],
51
+ "slow_queries": [
52
+ {"id": "q1", "sql": "SELECT * FROM tasks WHERE project_id=? AND assignee_id=? AND status != 'done' ORDER BY due_date ASC", "avg_ms": 20000, "main_table": "tasks", "rows_examined": 800000},
53
+ {"id": "q2", "sql": "SELECT * FROM activity_log WHERE workspace_id=? AND created_at > ? ORDER BY created_at DESC LIMIT 50", "avg_ms": 25000, "main_table": "activity_log", "rows_examined": 2000000},
54
+ {"id": "q3", "sql": "SELECT * FROM notifications WHERE user_id=? AND read=0", "avg_ms": 8000, "main_table": "notifications", "rows_examined": 400000}
55
+ ],
56
+ "missing_index_hints": [
57
+ {"table": "tasks", "columns": ["project_id", "assignee_id", "status", "due_date"], "reason": "4-column filter + ORDER BY"},
58
+ {"table": "activity_log", "columns": ["workspace_id", "created_at"], "reason": "Range query on 2M row table β€” also partition candidate"},
59
+ {"table": "notifications","columns": ["user_id", "read"], "reason": "Hot path β€” unread notifications per user"}
60
+ ],
61
+ "performance_score_baseline": 3.8,
62
+ "target_score": 68.0,
63
+ "max_steps": 50,
64
+ "optimal_actions": [
65
+ "inspect_query:q1", "inspect_query:q2", "inspect_query:q3",
66
+ "analyze_indexes:tasks", "analyze_indexes:activity_log",
67
+ "create_index:tasks:project_id,assignee_id,status,due_date",
68
+ "create_index:activity_log:workspace_id,created_at",
69
+ "create_index:notifications:user_id,read",
70
+ "partition_table:activity_log",
71
+ "analyze_statistics:tasks",
72
+ "analyze_statistics:activity_log",
73
+ "submit_report"
74
+ ],
75
+ "category": "saas_platform"
76
+ },
77
+ {
78
+ "id": "hard_s003",
79
+ "description": "Healthcare DB: 1M patient records. Compliance queries + clinical search + audit trail all slow.",
80
+ "tables": [
81
+ {"name": "patients", "rows": 1000000, "indexes": ["PRIMARY"], "size_mb": 4000},
82
+ {"name": "appointments", "rows": 500000, "indexes": ["PRIMARY"], "size_mb": 1500},
83
+ {"name": "prescriptions", "rows": 800000, "indexes": ["PRIMARY"], "size_mb": 2500},
84
+ {"name": "clinical_notes", "rows": 1200000, "indexes": ["PRIMARY"], "size_mb": 6000}
85
+ ],
86
+ "slow_queries": [
87
+ {"id": "q1", "sql": "SELECT * FROM appointments WHERE patient_id=? AND doctor_id=? AND appointment_date BETWEEN ? AND ?", "avg_ms": 18000, "main_table": "appointments", "rows_examined": 500000},
88
+ {"id": "q2", "sql": "SELECT * FROM prescriptions WHERE patient_id=? AND medication_code=? AND prescribed_at > ?", "avg_ms": 14000, "main_table": "prescriptions", "rows_examined": 800000},
89
+ {"id": "q3", "sql": "SELECT * FROM clinical_notes WHERE patient_id=? ORDER BY created_at DESC LIMIT 20", "avg_ms": 22000, "main_table": "clinical_notes", "rows_examined": 1200000}
90
+ ],
91
+ "missing_index_hints": [
92
+ {"table": "appointments", "columns": ["patient_id", "doctor_id", "appointment_date"], "reason": "Date range query + 2 foreign keys"},
93
+ {"table": "prescriptions", "columns": ["patient_id", "medication_code", "prescribed_at"], "reason": "Patient medication history"},
94
+ {"table": "clinical_notes","columns": ["patient_id", "created_at"], "reason": "Sorted history per patient on 1.2M rows"}
95
+ ],
96
+ "performance_score_baseline": 3.5,
97
+ "target_score": 68.0,
98
+ "max_steps": 50,
99
+ "optimal_actions": [
100
+ "inspect_query:q1", "inspect_query:q2", "inspect_query:q3",
101
+ "analyze_indexes:appointments", "analyze_indexes:clinical_notes",
102
+ "create_index:appointments:patient_id,doctor_id,appointment_date",
103
+ "create_index:prescriptions:patient_id,medication_code,prescribed_at",
104
+ "create_index:clinical_notes:patient_id,created_at",
105
+ "partition_table:clinical_notes",
106
+ "analyze_statistics:appointments",
107
+ "analyze_statistics:clinical_notes",
108
+ "submit_report"
109
+ ],
110
+ "category": "healthcare"
111
+ },
112
+ {
113
+ "id": "hard_s004",
114
+ "description": "Gaming leaderboard: 2M player records. Real-time ranking + history + match queries all degraded.",
115
+ "tables": [
116
+ {"name": "players", "rows": 2000000, "indexes": ["PRIMARY"], "size_mb": 5000},
117
+ {"name": "matches", "rows": 5000000, "indexes": ["PRIMARY"], "size_mb": 15000},
118
+ {"name": "leaderboards", "rows": 2000000, "indexes": ["PRIMARY"], "size_mb": 4000},
119
+ {"name": "achievements", "rows": 800000, "indexes": ["PRIMARY"], "size_mb": 2000}
120
+ ],
121
+ "slow_queries": [
122
+ {"id": "q1", "sql": "SELECT * FROM leaderboards WHERE game_mode=? AND season=? ORDER BY score DESC LIMIT 100", "avg_ms": 30000, "main_table": "leaderboards", "rows_examined": 2000000},
123
+ {"id": "q2", "sql": "SELECT * FROM matches WHERE player_id=? AND game_mode=? AND played_at > ? ORDER BY played_at DESC", "avg_ms": 25000, "main_table": "matches", "rows_examined": 5000000},
124
+ {"id": "q3", "sql": "SELECT * FROM achievements WHERE player_id=? AND unlocked=1", "avg_ms": 12000, "main_table": "achievements", "rows_examined": 800000}
125
+ ],
126
+ "missing_index_hints": [
127
+ {"table": "leaderboards", "columns": ["game_mode", "season", "score"], "reason": "Sorted leaderboard by mode+season"},
128
+ {"table": "matches", "columns": ["player_id", "game_mode", "played_at"], "reason": "Player history β€” 5M rows"},
129
+ {"table": "achievements", "columns": ["player_id", "unlocked"], "reason": "Unlocked achievements per player"}
130
+ ],
131
+ "performance_score_baseline": 2.8,
132
+ "target_score": 65.0,
133
+ "max_steps": 50,
134
+ "optimal_actions": [
135
+ "inspect_query:q1", "inspect_query:q2", "inspect_query:q3",
136
+ "analyze_indexes:leaderboards", "analyze_indexes:matches",
137
+ "create_index:leaderboards:game_mode,season,score",
138
+ "create_index:matches:player_id,game_mode,played_at",
139
+ "create_index:achievements:player_id,unlocked",
140
+ "partition_table:matches",
141
+ "analyze_statistics:leaderboards",
142
+ "analyze_statistics:matches",
143
+ "submit_report"
144
+ ],
145
+ "category": "gaming"
146
+ },
147
+ {
148
+ "id": "hard_s005",
149
+ "description": "Logistics platform: 6 tables, 3M shipment records. ETA queries, route optimization, and reporting all slow.",
150
+ "tables": [
151
+ {"name": "shipments", "rows": 3000000, "indexes": ["PRIMARY"], "size_mb": 9000},
152
+ {"name": "routes", "rows": 500000, "indexes": ["PRIMARY"], "size_mb": 1500},
153
+ {"name": "drivers", "rows": 100000, "indexes": ["PRIMARY"], "size_mb": 200},
154
+ {"name": "vehicles", "rows": 80000, "indexes": ["PRIMARY"], "size_mb": 150},
155
+ {"name": "warehouses", "rows": 20000, "indexes": ["PRIMARY"], "size_mb": 40},
156
+ {"name": "tracking", "rows": 10000000,"indexes": ["PRIMARY"], "size_mb": 30000}
157
+ ],
158
+ "slow_queries": [
159
+ {"id": "q1", "sql": "SELECT * FROM shipments WHERE origin_warehouse=? AND status=? AND scheduled_at BETWEEN ? AND ?", "avg_ms": 28000, "main_table": "shipments", "rows_examined": 3000000},
160
+ {"id": "q2", "sql": "SELECT * FROM tracking WHERE shipment_id=? ORDER BY recorded_at DESC LIMIT 50", "avg_ms": 35000, "main_table": "tracking", "rows_examined": 10000000},
161
+ {"id": "q3", "sql": "SELECT d.*, COUNT(s.id) FROM drivers d, shipments s WHERE d.id = s.driver_id AND s.status='in_transit' GROUP BY d.id", "avg_ms": 20000, "main_table": "shipments", "rows_examined": 3000000}
162
+ ],
163
+ "missing_index_hints": [
164
+ {"table": "shipments", "columns": ["origin_warehouse", "status", "scheduled_at"], "reason": "3-column filter on 3M rows"},
165
+ {"table": "tracking", "columns": ["shipment_id", "recorded_at"], "reason": "Lookup + sort on 10M row table β€” partition candidate"},
166
+ {"table": "shipments", "columns": ["driver_id", "status"], "reason": "JOIN + WHERE filter for driver stats"}
167
+ ],
168
+ "performance_score_baseline": 2.5,
169
+ "target_score": 65.0,
170
+ "max_steps": 50,
171
+ "optimal_actions": [
172
+ "inspect_query:q1", "inspect_query:q2", "inspect_query:q3",
173
+ "analyze_indexes:shipments", "analyze_indexes:tracking",
174
+ "create_index:shipments:origin_warehouse,status,scheduled_at",
175
+ "create_index:tracking:shipment_id,recorded_at",
176
+ "create_index:shipments:driver_id,status",
177
+ "rewrite_query:q3:SELECT d.id, d.name, COUNT(s.id) as active_shipments FROM drivers d INNER JOIN shipments s ON d.id = s.driver_id WHERE s.status='in_transit' GROUP BY d.id, d.name",
178
+ "partition_table:tracking",
179
+ "analyze_statistics:shipments",
180
+ "analyze_statistics:tracking",
181
+ "submit_report"
182
+ ],
183
+ "category": "logistics"
184
+ }
185
+ ]
dataset/medium_scenarios.json ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "id": "medium_s001",
4
+ "description": "E-commerce DB: 50K orders + 8K users. Two slow queries. Composite indexes + statistics update needed.",
5
+ "tables": [
6
+ {"name": "orders", "rows": 50000, "indexes": ["PRIMARY"], "size_mb": 280},
7
+ {"name": "users", "rows": 8000, "indexes": ["PRIMARY", "email_idx"], "size_mb": 15}
8
+ ],
9
+ "slow_queries": [
10
+ {"id": "q1", "sql": "SELECT * FROM orders WHERE user_id=? AND status=?", "avg_ms": 8500, "main_table": "orders", "rows_examined": 50000},
11
+ {"id": "q2", "sql": "SELECT COUNT(*) FROM orders o JOIN users u ON o.user_id=u.id WHERE u.country=?", "avg_ms": 3200, "main_table": "orders", "rows_examined": 50000}
12
+ ],
13
+ "missing_index_hints": [
14
+ {"table": "orders", "columns": ["user_id", "status"], "reason": "Composite WHERE filter"},
15
+ {"table": "users", "columns": ["country"], "reason": "JOIN + WHERE filter on country"}
16
+ ],
17
+ "performance_score_baseline": 12.5,
18
+ "target_score": 75.0,
19
+ "max_steps": 25,
20
+ "optimal_actions": [
21
+ "inspect_query:q1", "inspect_query:q2",
22
+ "analyze_indexes:orders", "analyze_indexes:users",
23
+ "create_index:orders:user_id,status",
24
+ "create_index:users:country",
25
+ "analyze_statistics:orders",
26
+ "submit_report"
27
+ ],
28
+ "category": "multi_table"
29
+ },
30
+ {
31
+ "id": "medium_s002",
32
+ "description": "Blog platform: 100K posts + 20K authors. Search and author lookup queries both slow.",
33
+ "tables": [
34
+ {"name": "posts", "rows": 100000, "indexes": ["PRIMARY"], "size_mb": 450},
35
+ {"name": "authors", "rows": 20000, "indexes": ["PRIMARY"], "size_mb": 40}
36
+ ],
37
+ "slow_queries": [
38
+ {"id": "q1", "sql": "SELECT * FROM posts WHERE author_id=? AND published=1 ORDER BY created_at DESC", "avg_ms": 6000, "main_table": "posts", "rows_examined": 100000},
39
+ {"id": "q2", "sql": "SELECT * FROM authors WHERE username=?", "avg_ms": 2100, "main_table": "authors", "rows_examined": 20000}
40
+ ],
41
+ "missing_index_hints": [
42
+ {"table": "posts", "columns": ["author_id", "published", "created_at"], "reason": "Multi-column filter + ORDER BY"},
43
+ {"table": "authors", "columns": ["username"], "reason": "Unique lookup by username"}
44
+ ],
45
+ "performance_score_baseline": 9.0,
46
+ "target_score": 78.0,
47
+ "max_steps": 25,
48
+ "optimal_actions": [
49
+ "inspect_query:q1", "inspect_query:q2",
50
+ "create_index:posts:author_id,published,created_at",
51
+ "create_index:authors:username",
52
+ "submit_report"
53
+ ],
54
+ "category": "multi_table"
55
+ },
56
+ {
57
+ "id": "medium_s003",
58
+ "description": "Inventory system: 80K products + 200K stock movements. Two queries needing index + rewrite.",
59
+ "tables": [
60
+ {"name": "products", "rows": 80000, "indexes": ["PRIMARY"], "size_mb": 200},
61
+ {"name": "stock_movements", "rows": 200000, "indexes": ["PRIMARY"], "size_mb": 600}
62
+ ],
63
+ "slow_queries": [
64
+ {"id": "q1", "sql": "SELECT * FROM stock_movements WHERE product_id=? AND movement_type=? AND created_at > ?", "avg_ms": 9000, "main_table": "stock_movements", "rows_examined": 200000},
65
+ {"id": "q2", "sql": "SELECT p.*, SUM(sm.quantity) FROM products p, stock_movements sm WHERE p.id = sm.product_id GROUP BY p.id", "avg_ms": 12000, "main_table": "products", "rows_examined": 200000}
66
+ ],
67
+ "missing_index_hints": [
68
+ {"table": "stock_movements", "columns": ["product_id", "movement_type", "created_at"], "reason": "Composite filter on 3 columns"},
69
+ {"table": "products", "columns": ["id"], "reason": "JOIN column — rewrite implicit JOIN to INNER JOIN"}
70
+ ],
71
+ "performance_score_baseline": 6.5,
72
+ "target_score": 72.0,
73
+ "max_steps": 30,
74
+ "optimal_actions": [
75
+ "inspect_query:q1", "inspect_query:q2",
76
+ "create_index:stock_movements:product_id,movement_type,created_at",
77
+ "rewrite_query:q2:SELECT p.id, p.name, SUM(sm.quantity) FROM products p INNER JOIN stock_movements sm ON p.id = sm.product_id GROUP BY p.id",
78
+ "analyze_statistics:stock_movements",
79
+ "submit_report"
80
+ ],
81
+ "category": "rewrite_and_index"
82
+ },
83
+ {
84
+ "id": "medium_s004",
85
+ "description": "Ticketing system: 60K tickets + 5K agents. Status queue and agent workload queries are slow.",
86
+ "tables": [
87
+ {"name": "tickets", "rows": 60000, "indexes": ["PRIMARY"], "size_mb": 180},
88
+ {"name": "agents", "rows": 5000, "indexes": ["PRIMARY"], "size_mb": 8}
89
+ ],
90
+ "slow_queries": [
91
+ {"id": "q1", "sql": "SELECT * FROM tickets WHERE status=? AND priority=? ORDER BY created_at ASC", "avg_ms": 5500, "main_table": "tickets", "rows_examined": 60000},
92
+ {"id": "q2", "sql": "SELECT agent_id, COUNT(*) as open_count FROM tickets WHERE status='open' GROUP BY agent_id", "avg_ms": 4200, "main_table": "tickets", "rows_examined": 60000}
93
+ ],
94
+ "missing_index_hints": [
95
+ {"table": "tickets", "columns": ["status", "priority", "created_at"], "reason": "Three-column filter with ORDER BY"},
96
+ {"table": "tickets", "columns": ["status", "agent_id"], "reason": "GROUP BY + WHERE filter"}
97
+ ],
98
+ "performance_score_baseline": 11.0,
99
+ "target_score": 76.0,
100
+ "max_steps": 25,
101
+ "optimal_actions": [
102
+ "inspect_query:q1", "inspect_query:q2",
103
+ "analyze_indexes:tickets",
104
+ "create_index:tickets:status,priority,created_at",
105
+ "create_index:tickets:status,agent_id",
106
+ "submit_report"
107
+ ],
108
+ "category": "multi_index"
109
+ },
110
+ {
111
+ "id": "medium_s005",
112
+ "description": "Analytics DB: 150K events + 10K users. Event funnel query and user lookup both need optimization.",
113
+ "tables": [
114
+ {"name": "events", "rows": 150000, "indexes": ["PRIMARY"], "size_mb": 700},
115
+ {"name": "users", "rows": 10000, "indexes": ["PRIMARY"], "size_mb": 20}
116
+ ],
117
+ "slow_queries": [
118
+ {"id": "q1", "sql": "SELECT * FROM events WHERE user_id=? AND event_type=? AND occurred_at BETWEEN ? AND ?", "avg_ms": 11000, "main_table": "events", "rows_examined": 150000},
119
+ {"id": "q2", "sql": "SELECT * FROM users WHERE signup_source=? AND created_at > ?", "avg_ms": 3000, "main_table": "users", "rows_examined": 10000}
120
+ ],
121
+ "missing_index_hints": [
122
+ {"table": "events", "columns": ["user_id", "event_type", "occurred_at"], "reason": "Range query on 3 columns"},
123
+ {"table": "users", "columns": ["signup_source", "created_at"], "reason": "Composite filter on signup data"}
124
+ ],
125
+ "performance_score_baseline": 5.5,
126
+ "target_score": 74.0,
127
+ "max_steps": 30,
128
+ "optimal_actions": [
129
+ "inspect_query:q1", "inspect_query:q2",
130
+ "create_index:events:user_id,event_type,occurred_at",
131
+ "create_index:users:signup_source,created_at",
132
+ "analyze_statistics:events",
133
+ "submit_report"
134
+ ],
135
+ "category": "analytics"
136
+ }
137
+ ]
env/__pycache__/models.cpython-312.pyc CHANGED
Binary files a/env/__pycache__/models.cpython-312.pyc and b/env/__pycache__/models.cpython-312.pyc differ
 
env/environment.py CHANGED
@@ -1,3 +1,9 @@
 
 
 
 
 
 
1
  import time
2
  import random
3
  from typing import Optional
@@ -9,29 +15,28 @@ from env.models import (
9
  )
10
  from env.tasks import task_manager
11
  from env.reward import compute_reward, is_done, MAX_STEPS
 
12
 
13
 
14
  class SQLDebuggerEnvironment:
15
  """
16
- OpenEnv-compliant SQL Query Debugger Environment.
17
-
18
- Implements the 3 required methods:
19
- reset() → Observation
20
- step() → (Observation, Reward, done, info)
21
- state() → EpisodeState
22
-
23
- Design principles:
24
- - Dense reward signal at every step
25
- - No state leakage between episodes
26
- - Graceful handling of all edge cases
27
- - Deterministic grading
28
- - Thread-safe episode state
29
  """
30
 
31
  def __init__(self):
32
- self._state = EpisodeState()
33
- self._current_task = None
34
- self._started_at = None
 
 
 
35
 
36
  # ─────────────────────────────────────────────
37
  # reset() β†’ Observation
@@ -39,14 +44,8 @@ class SQLDebuggerEnvironment:
39
 
40
  def reset(self, difficulty: Optional[str] = None, task_id: Optional[str] = None) -> Observation:
41
  """
42
- Starts a fresh episode. Clears ALL state from previous episode.
43
- Loads a new task from the dataset.
44
- Returns the initial Observation the agent sees.
45
-
46
- Edge cases handled:
47
- - reset() called mid-episode → cleanly resets, no state leakage
48
- - invalid difficulty → defaults to random
49
- - dataset empty → raises ValueError with clear message
50
  """
51
 
52
  # ── Resolve difficulty ────────────────────────────────────
@@ -54,7 +53,6 @@ class SQLDebuggerEnvironment:
54
  try:
55
  diff_enum = DifficultyLevel(difficulty.lower())
56
  except ValueError:
57
- # Invalid difficulty — pick random
58
  diff_enum = random.choice(list(DifficultyLevel))
59
  else:
60
  diff_enum = random.choice(list(DifficultyLevel))
@@ -65,7 +63,18 @@ class SQLDebuggerEnvironment:
65
  except Exception as e:
66
  raise ValueError(f"Failed to load task: {str(e)}")
67
 
68
- # ── Reset ALL state — no leakage ──────────────────────────
 
 
 
 
 
 
 
 
 
 
 
69
  self._current_task = task
70
  self._started_at = time.time()
71
  self._state = EpisodeState(
@@ -76,147 +85,194 @@ class SQLDebuggerEnvironment:
76
  done = False,
77
  hints_used = 0,
78
  previous_actions = [],
79
- action_counts = {},
 
 
 
 
 
 
80
  started_at = self._started_at,
81
  last_reward = 0.0,
82
  initialized = True,
83
  )
84
 
85
- # ── Build initial observation ─────────────────────────────
86
- context = task_manager.build_observation_context(task)
87
- return Observation(
88
- task_id = task["id"],
89
- task_description = task["description"],
90
- current_context = context,
91
- step_count = 0,
92
- difficulty = diff_enum,
93
- max_steps = MAX_STEPS,
94
- hints_used = 0,
95
- previous_actions = [],
96
- metadata = {
97
- "category": task.get("category", ""),
98
- "estimated_steps": task.get("estimated_fix_steps", 5),
99
- "started_at": self._started_at,
100
- }
101
- )
102
 
103
  # ─────────────────────────────────────────────
104
- # step() → (Observation, Reward, done, info)
105
  # ─────────────────────────────────────────────
106
 
107
  def step(self, action: Optional[Action]) -> StepResponse:
108
  """
109
- Accepts an Action, processes it, updates state,
110
- computes dense reward, returns next Observation.
111
-
112
- Edge cases handled:
113
- - step() called before reset() → auto-resets
114
- - null action → reward=-0.1, done=False, never crash
115
- - malformed action payload → catches ValidationError
116
- - agent loops (same action 3+ times) → loop penalty
117
- - episode already done → returns terminal observation
118
- - max steps reached → forces done=True
119
- - extremely long payload → truncated in models.py
120
  """
121
 
122
  # ── Auto-reset if not initialized ────────────────────────
123
  if not self._state.initialized or self._current_task is None:
124
  obs = self.reset()
125
  return StepResponse(
126
- observation=obs,
127
- reward=Reward(score=0.5, breakdown={"auto_reset": True}, feedback="Environment auto-reset."),
128
- done=False,
129
- info={"auto_reset": True}
130
  )
131
 
132
  # ── Episode already done ──────────────────────────────────
133
  if self._state.done:
134
  obs = self._build_observation()
135
  return StepResponse(
136
- observation=obs,
137
- reward=Reward(score=0.5, breakdown={"episode_done": True}, feedback="Episode already finished. Call reset()."),
138
- done=True,
139
- info={"episode_done": True, "total_reward": self._state.total_reward}
140
  )
141
 
142
- # ── Handle null / invalid action ─────────────────────────
143
  if action is None or action.payload is None:
144
  self._state.step_count += 1
145
- obs = self._build_observation()
146
- reward = Reward(
147
- score=0.001,
148
- breakdown={"invalid_action": 0.001},
149
- feedback="Null or invalid action received."
150
- )
151
- self._state.last_reward = -0.1
152
- self._state.total_reward = round(self._state.total_reward - 0.1, 4)
153
- done = self._state.step_count >= MAX_STEPS
154
  self._state.done = done
155
  return StepResponse(observation=obs, reward=reward, done=done, info={"error": "null_action"})
156
 
157
- # ── Validate action type ──────────────────────────────────
158
- try:
159
- action_type_val = action.action_type.value if hasattr(action.action_type, "value") else str(action.action_type)
160
- except Exception:
161
- action_type_val = "unknown"
162
 
163
  # ── Update step count ─────────────────────────────────────
164
  self._state.step_count += 1
165
  self._state.previous_actions.append(action_type_val)
166
- self._state.action_counts[action_type_val] = self._state.action_counts.get(action_type_val, 0) + 1
 
167
 
168
- # ── Track hints ───────────────────────────────────────────
169
- if action.action_type == ActionType.REQUEST_HINT:
170
  self._state.hints_used += 1
171
- # Inject hint into next observation context
172
  hint_text = task_manager.get_hint(self._current_task, self._state.hints_used)
173
  self._current_task["_last_hint"] = hint_text
174
 
175
- # ── Compute dense reward ──────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  reward = compute_reward(
177
- action = action,
178
- task_id = self._state.task_id,
179
- difficulty = self._state.difficulty,
180
- step_count = self._state.step_count,
181
- previous_actions = self._state.previous_actions[:-1], # exclude current
182
- hints_used = self._state.hints_used,
183
- estimated_steps = self._current_task.get("estimated_fix_steps", 5),
184
- action_counts = self._state.action_counts,
 
 
 
 
185
  )
186
 
 
 
 
187
  # ── Update cumulative reward ──────────────────────────────
188
  self._state.last_reward = reward.score
189
  self._state.total_reward = round(self._state.total_reward + reward.score, 4)
190
 
191
- # ── Check done condition ──────────────────────────────────
 
 
 
192
  done = is_done(
193
- action_type = action.action_type,
194
- step_count = self._state.step_count,
195
- grader_score = reward.breakdown.get("grader_score", 0.0),
 
196
  )
197
  self._state.done = done
198
 
199
- # ── Build next observation ────────────────────────────────
200
  obs = self._build_observation()
201
 
202
- # ── Build info dict ───────────────────────────────────────
203
  info = {
204
- "step_count": self._state.step_count,
205
- "total_reward": self._state.total_reward,
206
- "hints_used": self._state.hints_used,
207
- "action_counts": self._state.action_counts,
208
- "task_id": self._state.task_id,
209
- "difficulty": self._state.difficulty.value if self._state.difficulty else None,
 
 
 
210
  }
211
  if done:
212
  info["episode_summary"] = {
213
- "total_steps": self._state.step_count,
214
- "total_reward": self._state.total_reward,
215
- "hints_used": self._state.hints_used,
216
- "duration_sec": round(time.time() - (self._started_at or time.time()), 2),
 
 
 
 
217
  }
218
 
219
- # Normalize reward to strictly (0, 1) exclusive for validator compliance
220
  normalized_score = max(0.001, min(0.999, (reward.score + 1.0) / 2.0))
221
  reward = Reward(
222
  score=normalized_score,
@@ -231,13 +287,6 @@ class SQLDebuggerEnvironment:
231
  # ─────────────────────────────────────────────
232
 
233
  def state(self) -> EpisodeState:
234
- """
235
- Returns the full current state at any point.
236
- Must be JSON-serializable. Must always reflect latest step.
237
-
238
- Edge case: state() called before reset() → returns default empty state.
239
- Never crashes.
240
- """
241
  return self._state
242
 
243
  # ─────────────────────────────────────────────
@@ -245,13 +294,9 @@ class SQLDebuggerEnvironment:
245
  # ─────────────────────────────────────────────
246
 
247
  def _build_observation(self) -> Observation:
248
- """
249
- Builds the current Observation from internal state.
250
- Injects hint into context if one was just requested.
251
- CRITICAL: Never leaks fixed_query (ground truth) to agent.
252
- """
253
  if self._current_task is None:
254
- # Fallback safe observation
255
  return Observation(
256
  task_id = "none",
257
  task_description = "No task loaded. Call reset() first.",
@@ -264,14 +309,33 @@ class SQLDebuggerEnvironment:
264
  metadata = {}
265
  )
266
 
 
267
  context = task_manager.build_observation_context(self._current_task)
268
 
269
- # Inject hint if available
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  if "_last_hint" in self._current_task:
271
  context["last_hint"] = self._current_task["_last_hint"]
272
 
273
- # Add step progress info
274
- context["steps_remaining"] = MAX_STEPS - self._state.step_count
275
  context["total_reward_so_far"] = self._state.total_reward
276
 
277
  return Observation(
@@ -284,10 +348,11 @@ class SQLDebuggerEnvironment:
284
  hints_used = self._state.hints_used,
285
  previous_actions = self._state.previous_actions.copy(),
286
  metadata = {
287
- "category": self._current_task.get("category", ""),
288
- "estimated_steps": self._current_task.get("estimated_fix_steps", 5),
289
- "total_reward": self._state.total_reward,
290
- "action_counts": self._state.action_counts,
 
291
  }
292
  )
293
 
@@ -296,4 +361,4 @@ class SQLDebuggerEnvironment:
296
  # SINGLETON INSTANCE (used by FastAPI)
297
  # ─────────────────────────────────────────────
298
 
299
- environment = SQLDebuggerEnvironment()
 
1
+ """
2
+ env/environment.py — SQL Database Engineer Agent (SDEA)
3
+ Round 2: Long-horizon DB optimization environment.
4
+ Agent manages a simulated production database over 50 steps.
5
+ """
6
+
7
  import time
8
  import random
9
  from typing import Optional
 
15
  )
16
  from env.tasks import task_manager
17
  from env.reward import compute_reward, is_done, MAX_STEPS
18
+ from env.db_simulator import DatabaseSimulator
19
 
20
 
21
  class SQLDebuggerEnvironment:
22
  """
23
+ OpenEnv-compliant SQL Database Engineer Agent Environment.
24
+
25
+ Round 2 evolution:
26
+ - 50-step long-horizon episodes (up from 20)
27
+ - 10 action types including DB-specific actions
28
+ - DatabaseSimulator tracks real performance score 0-100
29
+ - Milestone bonuses at 25%/50%/75% improvement
30
+ - Backward compatible with Round 1 actions
 
 
 
 
 
31
  """
32
 
33
  def __init__(self):
34
+ self._state = EpisodeState()
35
+ self._current_task = None
36
+ self._started_at = None
37
+ self._db_sim: Optional[DatabaseSimulator] = None
38
+ self._milestones_earned: set = set()
39
+ self._baseline_score: float = 0.0
40
 
41
  # ─────────────────────────────────────────────
42
  # reset() β†’ Observation
 
44
 
45
  def reset(self, difficulty: Optional[str] = None, task_id: Optional[str] = None) -> Observation:
46
  """
47
+ Starts a fresh episode. Clears ALL state.
48
+ Loads scenario and initializes DatabaseSimulator.
 
 
 
 
 
 
49
  """
50
 
51
  # ── Resolve difficulty ────────────────────────────────────
 
53
  try:
54
  diff_enum = DifficultyLevel(difficulty.lower())
55
  except ValueError:
 
56
  diff_enum = random.choice(list(DifficultyLevel))
57
  else:
58
  diff_enum = random.choice(list(DifficultyLevel))
 
63
  except Exception as e:
64
  raise ValueError(f"Failed to load task: {str(e)}")
65
 
66
+ # ── Initialize DatabaseSimulator ──────────────────────────
67
+ # Only initialize for Round 2 scenarios (have 'tables' key)
68
+ if "tables" in task and "slow_queries" in task:
69
+ self._db_sim = DatabaseSimulator(task)
70
+ self._baseline_score = self._db_sim.get_performance_score()
71
+ else:
72
+ # Round 1 task — no DB simulator needed
73
+ self._db_sim = None
74
+ self._baseline_score = 0.0
75
+ self._milestones_earned = set()
76
+
77
+ # ── Reset episode state ───────────────────────────────────
78
  self._current_task = task
79
  self._started_at = time.time()
80
  self._state = EpisodeState(
 
85
  done = False,
86
  hints_used = 0,
87
  previous_actions = [],
88
+ action_counts = {
89
+ "_baseline_score": self._baseline_score,
90
+ "_target_score": task.get("target_score", 85.0),
91
+ "_milestones": [],
92
+ "_perf_history": [self._baseline_score],
93
+ "_best_score": self._baseline_score,
94
+ },
95
  started_at = self._started_at,
96
  last_reward = 0.0,
97
  initialized = True,
98
  )
99
 
100
+ return self._build_observation()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  # ─────────────────────────────────────────────
103
+ # step() → StepResponse
104
  # ─────────────────────────────────────────────
105
 
106
  def step(self, action: Optional[Action]) -> StepResponse:
107
  """
108
+ Processes an action, updates DB simulator, computes reward.
109
+ Handles all Round 2 DB engineering actions.
 
 
 
 
 
 
 
 
 
110
  """
111
 
112
  # ── Auto-reset if not initialized ────────────────────────
113
  if not self._state.initialized or self._current_task is None:
114
  obs = self.reset()
115
  return StepResponse(
116
+ observation = obs,
117
+ reward = Reward(score=0.5, breakdown={"auto_reset": True}, feedback="Environment auto-reset."),
118
+ done = False,
119
+ info = {"auto_reset": True}
120
  )
121
 
122
  # ── Episode already done ──────────────────────────────────
123
  if self._state.done:
124
  obs = self._build_observation()
125
  return StepResponse(
126
+ observation = obs,
127
+ reward = Reward(score=0.5, breakdown={"episode_done": True}, feedback="Episode finished. Call reset()."),
128
+ done = True,
129
+ info = {"episode_done": True, "total_reward": self._state.total_reward}
130
  )
131
 
132
+ # ── Handle null action ────────────────────────────────────
133
  if action is None or action.payload is None:
134
  self._state.step_count += 1
135
+ obs = self._build_observation()
136
+ reward = Reward(score=0.001, breakdown={"invalid_action": 0.001}, feedback="Null action.")
137
+ done = self._state.step_count >= MAX_STEPS
 
 
 
 
 
 
138
  self._state.done = done
139
  return StepResponse(observation=obs, reward=reward, done=done, info={"error": "null_action"})
140
 
141
+ action_type_val = action.action_type.value if hasattr(action.action_type, "value") else str(action.action_type)
142
+ action_type_enum = action.action_type
 
 
 
143
 
144
  # ── Update step count ─────────────────────────────────────
145
  self._state.step_count += 1
146
  self._state.previous_actions.append(action_type_val)
147
+ self._state.action_counts[action_type_val] = \
148
+ self._state.action_counts.get(action_type_val, 0) + 1
149
 
150
+ # ── Handle hints ──────────────────────────────────────────
151
+ if action_type_enum == ActionType.REQUEST_HINT:
152
  self._state.hints_used += 1
 
153
  hint_text = task_manager.get_hint(self._current_task, self._state.hints_used)
154
  self._current_task["_last_hint"] = hint_text
155
 
156
+ # ── Apply DB action and get delta ─────────────────────────
157
+ db_delta = 0.0
158
+ current_score = self._baseline_score
159
+ action_info = {}
160
+
161
+ if self._db_sim is not None:
162
+ payload = action.payload or {}
163
+
164
+ if action_type_enum == ActionType.INSPECT_QUERY:
165
+ qid = payload.get("query_id", "q1")
166
+ action_info = self._db_sim.inspect_query(qid)
167
+ self._current_task["_last_inspect"] = action_info
168
+ # No score change — investigation action
169
+
170
+ elif action_type_enum == ActionType.ANALYZE_INDEXES:
171
+ table = payload.get("table", "")
172
+ action_info = self._db_sim.analyze_indexes(table)
173
+ self._current_task["_last_analysis"] = action_info
174
+
175
+ elif action_type_enum == ActionType.CREATE_INDEX:
176
+ result = self._db_sim.apply_action("create_index", payload)
177
+ db_delta = result["delta"]
178
+ action_info = result
179
+
180
+ elif action_type_enum == ActionType.REWRITE_QUERY:
181
+ result = self._db_sim.apply_action("rewrite_query", payload)
182
+ db_delta = result["delta"]
183
+ action_info = result
184
+
185
+ elif action_type_enum == ActionType.ADD_COLUMN:
186
+ result = self._db_sim.apply_action("add_column", payload)
187
+ db_delta = result["delta"]
188
+ action_info = result
189
+
190
+ elif action_type_enum == ActionType.DROP_INDEX:
191
+ result = self._db_sim.apply_action("drop_index", payload)
192
+ db_delta = result["delta"]
193
+ action_info = result
194
+
195
+ elif action_type_enum == ActionType.PARTITION_TABLE:
196
+ result = self._db_sim.apply_action("partition_table", payload)
197
+ db_delta = result["delta"]
198
+ action_info = result
199
+
200
+ elif action_type_enum == ActionType.ANALYZE_STATS:
201
+ result = self._db_sim.apply_action("analyze_statistics", payload)
202
+ db_delta = result["delta"]
203
+ action_info = result
204
+
205
+ current_score = self._db_sim.get_performance_score()
206
+
207
+ # Update tracking in action_counts dict (used by /progress)
208
+ perf_history = self._state.action_counts.get("_perf_history", [])
209
+ perf_history.append(current_score)
210
+ self._state.action_counts["_perf_history"] = perf_history
211
+ self._state.action_counts["_best_score"] = self._db_sim.best_score
212
+
213
+ # ── Compute reward ────────────────────────────────────────
214
  reward = compute_reward(
215
+ action = action,
216
+ task_id = self._state.task_id,
217
+ difficulty = self._state.difficulty,
218
+ step_count = self._state.step_count,
219
+ previous_actions = self._state.previous_actions[:-1],
220
+ hints_used = self._state.hints_used,
221
+ estimated_steps = self._current_task.get("estimated_fix_steps", MAX_STEPS),
222
+ action_counts = self._state.action_counts,
223
+ db_delta = db_delta,
224
+ baseline_score = self._baseline_score,
225
+ current_score = current_score,
226
+ milestones_earned = self._milestones_earned,
227
  )
228
 
229
+ # Update milestone tracking
230
+ self._state.action_counts["_milestones"] = list(self._milestones_earned)
231
+
232
  # ── Update cumulative reward ──────────────────────────────
233
  self._state.last_reward = reward.score
234
  self._state.total_reward = round(self._state.total_reward + reward.score, 4)
235
 
236
+ # ── Check done ────────────────────────────────────────────
237
+ target_reached = (
238
+ self._db_sim.is_target_reached() if self._db_sim else False
239
+ )
240
  done = is_done(
241
+ action_type = action_type_enum,
242
+ step_count = self._state.step_count,
243
+ grader_score = reward.breakdown.get("grader_score", 0.0),
244
+ target_reached = target_reached,
245
  )
246
  self._state.done = done
247
 
248
+ # ── Build observation ─────────────────────────────────────
249
  obs = self._build_observation()
250
 
251
+ # ── Info dict ─────────────────────────────────────────────
252
  info = {
253
+ "step_count": self._state.step_count,
254
+ "total_reward": self._state.total_reward,
255
+ "hints_used": self._state.hints_used,
256
+ "task_id": self._state.task_id,
257
+ "difficulty": self._state.difficulty.value if self._state.difficulty else None,
258
+ "performance_score": current_score,
259
+ "db_delta": db_delta,
260
+ "milestones": list(self._milestones_earned),
261
+ "action_result": action_info,
262
  }
263
  if done:
264
  info["episode_summary"] = {
265
+ "total_steps": self._state.step_count,
266
+ "total_reward": self._state.total_reward,
267
+ "hints_used": self._state.hints_used,
268
+ "duration_sec": round(time.time() - (self._started_at or time.time()), 2),
269
+ "final_score": current_score,
270
+ "baseline_score": self._baseline_score,
271
+ "improvement": round(current_score - self._baseline_score, 2),
272
+ "milestones_earned": list(self._milestones_earned),
273
  }
274
 
275
+ # Normalize reward for validator compliance
276
  normalized_score = max(0.001, min(0.999, (reward.score + 1.0) / 2.0))
277
  reward = Reward(
278
  score=normalized_score,
 
287
  # ─────────────────────────────────────────────
288
 
289
  def state(self) -> EpisodeState:
 
 
 
 
 
 
 
290
  return self._state
291
 
292
  # ─────────────────────────────────────────────
 
294
  # ─────────────────────────────────────────────
295
 
296
  def _build_observation(self) -> Observation:
297
+ """Builds Observation from current state + DB simulator state."""
298
+
 
 
 
299
  if self._current_task is None:
 
300
  return Observation(
301
  task_id = "none",
302
  task_description = "No task loaded. Call reset() first.",
 
309
  metadata = {}
310
  )
311
 
312
+ # Base context from task
313
  context = task_manager.build_observation_context(self._current_task)
314
 
315
+ # Inject DB simulator state
316
+ if self._db_sim is not None:
317
+ db_state = self._db_sim.get_current_state()
318
+ context.update({
319
+ "performance_score": db_state["performance_score"],
320
+ "target_score": db_state["target_score"],
321
+ "baseline_score": db_state["baseline_score"],
322
+ "tables": db_state["tables"],
323
+ "slow_queries": db_state["slow_queries"],
324
+ "indexes": db_state["indexes"],
325
+ "improvement_history": db_state["history"],
326
+ "best_score": db_state["best_score"],
327
+ "milestones_earned": list(self._milestones_earned),
328
+ })
329
+
330
+ # Inject last action result if available
331
+ if "_last_inspect" in self._current_task:
332
+ context["last_inspect_result"] = self._current_task["_last_inspect"]
333
+ if "_last_analysis" in self._current_task:
334
+ context["last_analysis_result"] = self._current_task["_last_analysis"]
335
  if "_last_hint" in self._current_task:
336
  context["last_hint"] = self._current_task["_last_hint"]
337
 
338
+ context["steps_remaining"] = MAX_STEPS - self._state.step_count
 
339
  context["total_reward_so_far"] = self._state.total_reward
340
 
341
  return Observation(
 
348
  hints_used = self._state.hints_used,
349
  previous_actions = self._state.previous_actions.copy(),
350
  metadata = {
351
+ "category": self._current_task.get("category", ""),
352
+ "baseline_score": self._baseline_score,
353
+ "target_score": self._current_task.get("target_score", 85.0),
354
+ "total_reward": self._state.total_reward,
355
+ "milestones": list(self._milestones_earned),
356
  }
357
  )
358
 
 
361
  # SINGLETON INSTANCE (used by FastAPI)
362
  # ─────────────────────────────────────────────
363
 
364
+ environment = SQLDebuggerEnvironment()
env/models.py CHANGED
@@ -4,7 +4,9 @@ from enum import Enum
4
  import time
5
 
6
 
 
7
  # ENUMS
 
8
 
9
  class DifficultyLevel(str, Enum):
10
  EASY = "easy"
@@ -13,42 +15,57 @@ class DifficultyLevel(str, Enum):
13
 
14
 
15
  class ActionType(str, Enum):
16
- IDENTIFY_ERROR = "identify_error"
17
- PROPOSE_FIX = "propose_fix"
18
- SUBMIT_ANSWER = "submit_answer"
19
- REQUEST_HINT = "request_hint"
20
- EXPLAIN_ISSUE = "explain_issue"
21
- OPTIMIZE_QUERY = "optimize_query"
22
-
23
- # CORE MODELS
24
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  class Observation(BaseModel):
27
- task_id: str = Field(..., description="Unique task identifier")
28
- task_description: str = Field(..., description="What the agent must do")
29
- current_context: dict = Field(..., description="What the agent currently sees")
30
- step_count: int = Field(default=0, ge=0, description="Steps taken so far")
31
  difficulty: DifficultyLevel = Field(..., description="Task difficulty level")
32
- max_steps: int = Field(default=20, description="Maximum steps allowed")
33
- hints_used: int = Field(default=0, description="Number of hints used")
34
- previous_actions: list[str] = Field(default_factory=list, description="History of action types taken")
35
- metadata: dict = Field(default_factory=dict, description="Extra task metadata")
36
 
37
  model_config = {"json_schema_extra": {
38
  "example": {
39
- "task_id": "easy_001",
40
- "task_description": "Fix the SQL syntax error in the query below.",
41
  "current_context": {
42
- "buggy_query": "SELECT id, name FROM users WHERE id = 1 AND",
43
- "error_message": "SyntaxError: unexpected end of input",
44
- "database_schema": "users(id INT, name VARCHAR, email VARCHAR)"
 
45
  },
46
  "step_count": 0,
47
  "difficulty": "easy",
48
- "max_steps": 20,
49
  "hints_used": 0,
50
  "previous_actions": [],
51
- "metadata": {"category": "syntax", "estimated_fix_steps": 2}
52
  }
53
  }}
54
 
@@ -67,7 +84,6 @@ class Action(BaseModel):
67
  @field_validator("payload")
68
  @classmethod
69
  def truncate_long_strings(cls, v):
70
- # Edge case: extremely long agent output β€” truncate gracefully
71
  def truncate(obj, max_len=5000):
72
  if isinstance(obj, str) and len(obj) > max_len:
73
  return obj[:max_len] + "...[truncated]"
@@ -78,12 +94,10 @@ class Action(BaseModel):
78
 
79
  model_config = {"json_schema_extra": {
80
  "example": {
81
- "action_type": "submit_answer",
82
  "payload": {
83
- "fixed_query": "SELECT id, name FROM users WHERE id = 1",
84
- "explanation": "Removed the trailing AND which caused a syntax error",
85
- "error_type": "syntax",
86
- "confidence": 0.95
87
  }
88
  }
89
  }}
@@ -103,49 +117,53 @@ class Reward(BaseModel):
103
  "example": {
104
  "score": 0.75,
105
  "breakdown": {
106
- "correct_answer": 0.5,
107
- "explanation": 0.2,
108
- "confidence": 0.05,
109
- "step_efficiency": 0.0
110
  },
111
- "feedback": "Correct fix applied. Good explanation provided. Minor efficiency penalty."
112
  }
113
  }}
114
 
115
 
 
116
  # EPISODE STATE (used by state() endpoint)
 
117
 
118
  class EpisodeState(BaseModel):
119
- task_id: Optional[str] = Field(default=None)
120
  difficulty: Optional[DifficultyLevel] = Field(default=None)
121
  step_count: int = Field(default=0)
122
  total_reward: float = Field(default=0.0)
123
  done: bool = Field(default=False)
124
  hints_used: int = Field(default=0)
125
  previous_actions: list[str] = Field(default_factory=list)
126
- action_counts: dict[str, int] = Field(default_factory=dict)
127
  started_at: Optional[float] = Field(default=None)
128
  last_reward: float = Field(default=0.0)
129
  initialized: bool = Field(default=False)
130
 
131
  model_config = {"json_schema_extra": {
132
  "example": {
133
- "task_id": "medium_002",
134
- "difficulty": "medium",
135
  "step_count": 3,
136
- "total_reward": 0.45,
137
  "done": False,
138
- "hints_used": 1,
139
- "previous_actions": ["identify_error", "request_hint", "propose_fix"],
140
- "action_counts": {"identify_error": 1, "request_hint": 1, "propose_fix": 1},
141
  "started_at": 1700000000.0,
142
- "last_reward": 0.25,
143
  "initialized": True
144
  }
145
  }}
146
 
147
 
 
148
  # API REQUEST / RESPONSE WRAPPERS
 
149
 
150
  class StepResponse(BaseModel):
151
  observation: Observation
@@ -157,15 +175,15 @@ class ResetResponse(BaseModel):
157
  observation: Observation
158
 
159
  class TaskInfo(BaseModel):
160
- id: str
161
- difficulty: DifficultyLevel
162
- description: str
163
- action_schema: dict # REQUIRED by validator β€” field definitions not just names
164
 
165
  class TaskListResponse(BaseModel):
166
- tasks: list[TaskInfo]
167
- total: int
168
- action_types: list[str]
169
 
170
  class BaselineResult(BaseModel):
171
  task_id: str
@@ -180,7 +198,7 @@ class BaselineResult(BaseModel):
180
  return max(0.001, min(0.999, round(float(v), 4)))
181
 
182
  class BaselineResponse(BaseModel):
183
- results: list[BaselineResult]
184
  average_score: float
185
  completed_at: float = Field(default_factory=time.time)
186
 
@@ -201,13 +219,30 @@ class GraderResponse(BaseModel):
201
 
202
  model_config = {"json_schema_extra": {
203
  "example": {
204
- "score": 0.75,
205
- "feedback": "Correct fix applied.",
206
- "breakdown": {"fix_correctness": 0.5, "explanation": 0.15, "confidence": 0.05}
207
  }
208
  }}
209
 
210
  class HealthResponse(BaseModel):
211
- status: str = "ok"
212
- version: str = "1.0.0"
213
- uptime: float = Field(default_factory=time.time)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import time
5
 
6
 
7
+ # ─────────────────────────────────────────────
8
  # ENUMS
9
+ # ─────────────────────────────────────────────
10
 
11
  class DifficultyLevel(str, Enum):
12
  EASY = "easy"
 
15
 
16
 
17
class ActionType(str, Enum):
    """Closed set of action-type identifiers; each member's value is a plain string.

    Round 1 members are kept so older payloads still validate; Round 2
    members add the database-optimisation actions.
    """

    # -- Round 1 actions (kept for backward compatibility) --
    IDENTIFY_ERROR = "identify_error"
    PROPOSE_FIX = "propose_fix"
    SUBMIT_ANSWER = "submit_answer"
    REQUEST_HINT = "request_hint"
    EXPLAIN_ISSUE = "explain_issue"
    OPTIMIZE_QUERY = "optimize_query"

    # -- Round 2 new actions --
    INSPECT_QUERY = "inspect_query"
    ANALYZE_INDEXES = "analyze_indexes"
    CREATE_INDEX = "create_index"
    REWRITE_QUERY = "rewrite_query"
    ADD_COLUMN = "add_column"
    DROP_INDEX = "drop_index"
    PARTITION_TABLE = "partition_table"
    # NOTE: the member is abbreviated but the string value is the full word.
    ANALYZE_STATS = "analyze_statistics"
    SUBMIT_REPORT = "submit_report"
36
+
37
+
38
+ # ─────────────────────────────────────────────
39
+ # CORE MODELS
40
+ # ─────────────────────────────────────────────
41
 
42
class Observation(BaseModel):
    """Snapshot describing a task plus step/hint bookkeeping.

    ``task_id``, ``task_description``, ``current_context`` and ``difficulty``
    are required; every other field has a default.
    """

    task_id: str = Field(..., description="Unique task identifier")
    task_description: str = Field(..., description="What the agent must do")
    current_context: dict = Field(..., description="What the agent currently sees")
    # ge=0 makes validation reject negative step counts.
    step_count: int = Field(default=0, ge=0, description="Steps taken so far")
    difficulty: DifficultyLevel = Field(..., description="Task difficulty level")
    max_steps: int = Field(default=50, description="Maximum steps allowed")
    hints_used: int = Field(default=0, description="Number of hints used")
    # default_factory gives each instance its own list / dict.
    previous_actions: list[str] = Field(default_factory=list, description="History of action types taken")
    metadata: dict = Field(default_factory=dict, description="Extra task metadata")

    # Example payload attached to the generated JSON schema.
    model_config = {"json_schema_extra": {
        "example": {
            "task_id": "easy_s001",
            "task_description": "Optimize a slow user lookup query on 10K users table.",
            "current_context": {
                "tables": [{"name": "users", "rows": 10000, "indexes": ["PRIMARY"]}],
                "slow_queries": [{"id": "q1", "sql": "SELECT * FROM users WHERE email=?", "avg_ms": 2000}],
                "performance_score": 8.0,
                "target_score": 80.0
            },
            "step_count": 0,
            "difficulty": "easy",
            "max_steps": 50,
            "hints_used": 0,
            "previous_actions": [],
            "metadata": {"scenario_id": "easy_s001", "baseline_score": 8.0}
        }
    }}
71
 
 
84
  @field_validator("payload")
85
  @classmethod
86
  def truncate_long_strings(cls, v):
 
87
  def truncate(obj, max_len=5000):
88
  if isinstance(obj, str) and len(obj) > max_len:
89
  return obj[:max_len] + "...[truncated]"
 
94
 
95
  model_config = {"json_schema_extra": {
96
  "example": {
97
+ "action_type": "create_index",
98
  "payload": {
99
+ "table": "users",
100
+ "columns": ["email"]
 
 
101
  }
102
  }
103
  }}
 
117
  "example": {
118
  "score": 0.75,
119
  "breakdown": {
120
+ "step_reward": 0.05,
121
+ "delta_reward": 0.40,
122
+ "milestone_bonus": 0.15,
123
+ "total": 0.60
124
  },
125
+ "feedback": "Index created. Performance improved 55%. Milestone bonus earned!"
126
  }
127
  }}
128
 
129
 
130
+ # ─────────────────────────────────────────────
131
  # EPISODE STATE (used by state() endpoint)
132
+ # ─────────────────────────────────────────────
133
 
134
class EpisodeState(BaseModel):
    """Per-episode bookkeeping; every field has a default, so it can be built empty."""

    task_id: Optional[str] = Field(default=None)
    difficulty: Optional[DifficultyLevel] = Field(default=None)
    step_count: int = Field(default=0)
    total_reward: float = Field(default=0.0)
    done: bool = Field(default=False)
    hints_used: int = Field(default=0)
    previous_actions: list[str] = Field(default_factory=list)
    # Values are typed Any; the schema example below stores integer counts
    # keyed by action-type string.
    action_counts: dict[str, Any] = Field(default_factory=dict)
    # presumably an epoch timestamp (the example uses 1700000000.0) -- confirm against the writer
    started_at: Optional[float] = Field(default=None)
    last_reward: float = Field(default=0.0)
    initialized: bool = Field(default=False)

    # Example payload attached to the generated JSON schema.
    model_config = {"json_schema_extra": {
        "example": {
            "task_id": "easy_s001",
            "difficulty": "easy",
            "step_count": 3,
            "total_reward": 0.65,
            "done": False,
            "hints_used": 0,
            "previous_actions": ["inspect_query", "analyze_indexes", "create_index"],
            "action_counts": {"inspect_query": 1, "analyze_indexes": 1, "create_index": 1},
            "started_at": 1700000000.0,
            "last_reward": 0.45,
            "initialized": True
        }
    }}
162
 
163
 
164
+ # ─────────────────────────────────────────────
165
  # API REQUEST / RESPONSE WRAPPERS
166
+ # ─────────────────────────────────────────────
167
 
168
  class StepResponse(BaseModel):
169
  observation: Observation
 
175
  observation: Observation
176
 
177
class TaskInfo(BaseModel):
    """Description of a single task; all four fields are required."""

    id: str
    difficulty: DifficultyLevel
    description: str
    # presumably the per-action payload field definitions (cf. ACTION_SCHEMA
    # in env/tasks.py) -- TODO confirm against the code that builds TaskInfo
    action_schema: dict
182
 
183
class TaskListResponse(BaseModel):
    """Response wrapper carrying a list of TaskInfo entries plus summary fields."""

    tasks: list[TaskInfo]
    # presumably len(tasks); not enforced by the model -- verify against caller
    total: int
    action_types: list[str]
187
 
188
  class BaselineResult(BaseModel):
189
  task_id: str
 
198
  return max(0.001, min(0.999, round(float(v), 4)))
199
 
200
  class BaselineResponse(BaseModel):
201
+ results: list[BaselineResult]
202
  average_score: float
203
  completed_at: float = Field(default_factory=time.time)
204
 
 
219
 
220
  model_config = {"json_schema_extra": {
221
  "example": {
222
+ "score": 0.82,
223
+ "feedback": "Performance improved from 12.5 to 85.0. Excellent optimization!",
224
+ "breakdown": {"perf_improvement": 0.60, "step_efficiency": 0.12, "index_quality": 0.10}
225
  }
226
  }}
227
 
228
class HealthResponse(BaseModel):
    """Health-check payload: a status string, the API version and the uptime."""

    status: str = "ok"
    version: str = "2.0.0"
    # Elapsed seconds, as computed by the server (time.time() - startup time).
    # The default used to be ``default_factory=time.time``, which yields an
    # absolute epoch timestamp (~1.7e9) rather than an elapsed duration, so a
    # response built without an explicit value reported a nonsensical uptime.
    # Default to 0.0 instead; callers that pass ``uptime`` are unaffected.
    uptime: float = Field(default=0.0, description="Seconds elapsed since server startup")
232
+
233
+
234
+ # ─────────────────────────────────────────────
235
+ # ROUND 2 — PROGRESS RESPONSE
236
+ # ─────────────────────────────────────────────
237
+
238
class ProgressResponse(BaseModel):
    """Progress summary for an episode; every field has a default."""

    scenario_id: Optional[str] = Field(default=None)
    performance_score: float = Field(default=0.0, description="Current DB performance score 0-100")
    baseline_score: float = Field(default=0.0, description="Starting score this episode")
    target_score: float = Field(default=85.0, description="Score needed to succeed")
    # presumably one performance score per step -- confirm against the producer
    improvement_history: list[float] = Field(default_factory=list)
    # presumably the milestone thresholds (e.g. 0.25 / 0.50 / 0.75) already
    # paid out by env/reward.py -- confirm against the producer
    milestones_earned: list[float] = Field(default_factory=list)
    best_score: float = Field(default=0.0)
    steps_used: int = Field(default=0)
    # Default of 50 matches MAX_STEPS in env/reward.py.
    budget_remaining: int = Field(default=50)
    total_reward: float = Field(default=0.0)
env/reward.py CHANGED
@@ -1,41 +1,94 @@
1
  from env.models import Action, Reward, DifficultyLevel, ActionType
2
  from env.graders import grade
3
 
 
4
  # CONSTANTS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- MAX_STEPS = 20
7
- HINT_PENALTY = -0.05 # Per hint requested
8
- LOOP_PENALTY = -0.05 # Same action 3+ times in a row
9
- INVALID_PENALTY = -0.10 # Null / malformed action
10
- STEP_EFFICIENCY_BONUS = 0.10 # Bonus for solving in fewer steps than estimated
11
-
12
- # Dense reward per action type (before grader score)
13
  STEP_REWARDS = {
14
- ActionType.IDENTIFY_ERROR: 0.15, # Rewarded for diagnosing
15
- ActionType.PROPOSE_FIX: 0.25, # Rewarded for attempting fix
16
- ActionType.SUBMIT_ANSWER: 0.00, # Final score comes from grader
17
- ActionType.REQUEST_HINT: 0.00, # No reward, only penalty
18
- ActionType.EXPLAIN_ISSUE: 0.10, # Rewarded for explaining
19
- ActionType.OPTIMIZE_QUERY: 0.20, # Rewarded for optimization attempt
 
 
 
 
 
 
 
 
 
 
 
20
  }
21
 
 
 
 
 
 
 
22
 
23
- # LOOP DETECTOR
24
 
 
 
 
25
 
26
- def _detect_loop(previous_actions: list[str], current_action: str) -> bool:
 
 
 
 
27
  """
28
- Returns True if the agent has submitted the same action type
29
- 3 or more times in a row β€” indicating a stuck loop.
30
  """
31
- if len(previous_actions) < 2:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  return False
33
- last_two = previous_actions[-2:]
34
- return all(a == current_action for a in last_two)
35
 
36
 
37
  def _count_consecutive(previous_actions: list[str], current_action: str) -> int:
38
- """Count how many times the current action has been repeated consecutively."""
39
  count = 1
40
  for a in reversed(previous_actions):
41
  if a == current_action:
@@ -45,24 +98,22 @@ def _count_consecutive(previous_actions: list[str], current_action: str) -> int:
45
  return count
46
 
47
 
 
48
  # EFFICIENCY BONUS
49
-
50
-
51
- def _efficiency_bonus(step_count: int, estimated_steps: int) -> float:
52
- """
53
- Bonus reward if agent solves faster than estimated.
54
- Encourages efficient reasoning, not just correct answers.
55
- """
56
- if step_count <= 0 or estimated_steps <= 0:
57
- return 0.0
58
- if step_count <= estimated_steps:
59
- ratio = step_count / estimated_steps
60
- # More bonus the faster β€” scales from 0.10 down to 0.0
61
- return round(STEP_EFFICIENCY_BONUS * (1.0 - ratio + 0.1), 4)
62
  return 0.0
63
 
64
 
 
65
  # MAIN REWARD FUNCTION
 
66
 
67
  def compute_reward(
68
  action: Action,
@@ -73,25 +124,33 @@ def compute_reward(
73
  hints_used: int,
74
  estimated_steps: int,
75
  action_counts: dict[str, int],
 
 
 
 
 
76
  ) -> Reward:
77
  """
78
- Computes a DENSE reward signal for every step.
79
- Never returns 0.0 for all steps β€” reward varies at each step.
80
-
81
- Dense reward components:
82
- 1. Step reward β€” small reward just for taking valid action
83
- 2. Grader score β€” full grader score on submit_answer / optimize_query
84
- 3. Loop penalty β€” repeated same action 3+ times
85
- 4. Hint penalty β€” accumulated hint cost
86
- 5. Efficiency bonus β€” solved faster than estimated steps
87
- 6. Invalid penalty β€” null / malformed action
88
-
89
- Score is always clamped to [-1.0, 1.0].
90
  """
91
 
92
- breakdown = {}
 
 
 
93
  feedback_parts = []
94
- final_score = 0.0
95
 
96
  # ── Edge case: null action ────────────────────────────────────
97
  if action is None or action.payload is None:
@@ -100,105 +159,155 @@ def compute_reward(
100
  breakdown={"invalid_action": 0.001},
101
  feedback="Invalid or null action received."
102
  )
103
- action_type_val = action.action_type.value if hasattr(action.action_type, "value") else str(action.action_type)
 
104
  action_type_enum = action.action_type
105
 
106
- # ── 1. Step reward (dense signal) ────────────────────────────
107
  step_reward = STEP_REWARDS.get(action_type_enum, 0.05)
108
  breakdown["step_reward"] = round(step_reward, 4)
109
  final_score += step_reward
110
  if step_reward > 0:
111
- feedback_parts.append(f"Action '{action_type_val}' rewarded +{step_reward}.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- # ── 2. Grader score for terminal actions ──────────────────────
114
  grader_score = 0.0
115
- is_terminal = action_type_enum in (ActionType.SUBMIT_ANSWER, ActionType.OPTIMIZE_QUERY)
116
 
117
- if is_terminal:
118
  raw_score, grader_breakdown, grader_feedback = grade(action, task_id)
119
  grader_score = raw_score
120
- breakdown["grader_score"] = round(grader_score, 4)
121
  breakdown["grader_breakdown"] = grader_breakdown
122
  final_score += grader_score
123
  feedback_parts.append(grader_feedback)
124
 
125
- # Efficiency bonus β€” only on correct terminal action
126
  if grader_score >= 0.5:
127
- eff_bonus = _efficiency_bonus(step_count, estimated_steps)
128
  if eff_bonus > 0:
129
  final_score += eff_bonus
130
  breakdown["efficiency_bonus"] = round(eff_bonus, 4)
131
- feedback_parts.append(f"Efficiency bonus +{eff_bonus} for solving in {step_count} steps.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  elif action_type_enum == ActionType.PROPOSE_FIX:
134
- # Partial grader score for propose_fix β€” encourages iterative improvement
135
  raw_score, grader_breakdown, _ = grade(action, task_id)
136
- partial = round(raw_score * 0.4, 4) # 40% of full grader score
137
- grader_score = partial
138
  breakdown["partial_grader_score"] = partial
139
  final_score += partial
140
- if partial > 0:
141
- feedback_parts.append(f"Partial fix credit +{partial}.")
142
 
143
  elif action_type_enum == ActionType.IDENTIFY_ERROR:
144
- # Small grader check on error identification
145
  raw_score, _, _ = grade(action, task_id)
146
- partial = round(raw_score * 0.2, 4) # 20% for identification step
147
  breakdown["identification_score"] = partial
148
  final_score += partial
149
 
150
- # ── 3. Loop penalty ───────────────────────────────────────────
151
  if _detect_loop(previous_actions, action_type_val):
152
  consecutive = _count_consecutive(previous_actions, action_type_val)
153
- loop_pen = LOOP_PENALTY * min(consecutive - 2, 3) # Cap at 3x penalty
154
  final_score += loop_pen
155
  breakdown["loop_penalty"] = round(loop_pen, 4)
156
- feedback_parts.append(f"Loop detected ({consecutive}x same action). Penalty {loop_pen}.")
157
 
158
- # ── 4. Hint penalty ───────────────────────────────────────────
159
  if action_type_enum == ActionType.REQUEST_HINT:
160
- hint_pen = HINT_PENALTY
161
- final_score += hint_pen
162
- breakdown["hint_penalty"] = round(hint_pen, 4)
163
- feedback_parts.append(f"Hint requested. Penalty {hint_pen}.")
164
-
165
- # ── 5. Max steps penalty ──────────────────────────────────────
166
- if step_count >= MAX_STEPS - 1:
167
- final_score += -0.10
168
- breakdown["max_steps_penalty"] = -0.10
169
- feedback_parts.append("Approaching max steps limit. Penalty applied.")
170
-
171
- # ── Clamp to [-1.0, 1.0] ─────────────────────────────────────
172
- # Clamp strictly between 0.001 and 0.999 for validator compliance
 
 
 
 
173
  final_score = round(max(0.001, min(0.999, final_score)), 4)
174
  breakdown["total"] = final_score
175
 
176
  feedback = " ".join(feedback_parts) if feedback_parts else "Step processed."
177
 
178
- return Reward(
179
- score=final_score,
180
- breakdown=breakdown,
181
- feedback=feedback
182
- )
183
 
184
 
 
185
  # EPISODE DONE CONDITION
 
186
 
187
  def is_done(
188
- action_type: ActionType,
189
- step_count: int,
190
- grader_score: float = 0.0,
 
191
  ) -> bool:
192
  """
193
  Episode ends when:
194
- 1. Agent submits final answer (submit_answer / optimize_query)
195
  2. Max steps reached
196
- 3. Perfect score achieved
197
  """
198
- if action_type in (ActionType.SUBMIT_ANSWER, ActionType.OPTIMIZE_QUERY):
199
  return True
200
  if step_count >= MAX_STEPS:
201
  return True
202
  if grader_score >= 1.0:
203
  return True
204
- return False
 
 
 
1
  from env.models import Action, Reward, DifficultyLevel, ActionType
2
  from env.graders import grade
3
 
4
+ # ─────────────────────────────────────────────
5
  # CONSTANTS
6
+ # ─────────────────────────────────────────────
7
+
8
MAX_STEPS = 50  # Round 2: long-horizon episodes
HINT_PENALTY = -0.10  # Added to the score on every request_hint action (raised from Round 1)
# Applied when the current action type equals the previous one; compute_reward
# multiplies it by min(consecutive - 1, 3). The check looks at the action type
# only, not at the target or at whether performance improved.
LOOP_PENALTY = -0.08
INVALID_PENALTY = -0.10  # Null / malformed action
BACKTRACK_PENALTY = -0.05  # Applied when a step's db_delta is below -1.0
# Applied to non-terminal steps once step_count >= MAX_STEPS - 2.
BUDGET_EXHAUSTION_PEN = -0.15
# Upper bound of the efficiency bonus; paid when step_count <= 70% of the
# budget and scaled by the unused fraction of the budget.
EFFICIENCY_BONUS = 0.10

# Milestone thresholds: {improvement_fraction: bonus_reward}
MILESTONE_THRESHOLDS = {
    0.25: 0.15,  # 25% improvement -> +0.15 bonus
    0.50: 0.25,  # 50% improvement -> +0.25 bonus
    0.75: 0.40,  # 75% improvement -> +0.40 bonus
}

# Step rewards for Round 2 actions (dense signal). compute_reward falls back
# to 0.05 for any action type missing from this table.
STEP_REWARDS = {
    # -- Round 2 actions --
    ActionType.INSPECT_QUERY: 0.05,  # Investigation rewarded
    ActionType.ANALYZE_INDEXES: 0.05,  # Investigation rewarded
    ActionType.CREATE_INDEX: 0.10,  # Core optimization action
    ActionType.REWRITE_QUERY: 0.15,  # High-value rewrite
    ActionType.ADD_COLUMN: 0.08,  # Denormalization
    ActionType.DROP_INDEX: 0.05,  # Clean up overhead
    ActionType.PARTITION_TABLE: 0.15,  # Big structural improvement
    ActionType.ANALYZE_STATS: 0.05,  # Maintenance action
    # Terminal; its score is computed from DB performance in compute_reward,
    # not by the grader.
    ActionType.SUBMIT_REPORT: 0.00,
    ActionType.REQUEST_HINT: 0.00,  # No reward, only penalty
    # -- Round 1 backward compat --
    ActionType.IDENTIFY_ERROR: 0.15,
    ActionType.PROPOSE_FIX: 0.25,
    ActionType.SUBMIT_ANSWER: 0.00,
    ActionType.EXPLAIN_ISSUE: 0.10,
    ActionType.OPTIMIZE_QUERY: 0.20,
}

# Terminal actions that end the episode
TERMINAL_ACTIONS = {
    ActionType.SUBMIT_ANSWER,
    ActionType.OPTIMIZE_QUERY,
    ActionType.SUBMIT_REPORT,
}
50
 
 
51
 
52
+ # ─────────────────────────────────────────────
53
+ # MILESTONE TRACKER
54
+ # ─────────────────────────────────────────────
55
 
56
def check_milestones(
    baseline_score: float,
    new_score: float,
    earned: set,
) -> tuple[float, list[float]]:
    """Pay out one-time milestone bonuses for the improvement achieved so far.

    The improvement fraction is ``(new_score - baseline_score)`` divided by the
    available headroom ``max(1.0, 100.0 - baseline_score)``. Every threshold in
    ``MILESTONE_THRESHOLDS`` that the fraction has reached, and that is not
    already in ``earned``, is paid once.

    Args:
        baseline_score: Performance score at the start of the episode.
        new_score: Current performance score.
        earned: Thresholds already paid; newly paid thresholds are added to it
            in place.

    Returns:
        ``(bonus, newly_earned)``: the summed bonus rounded to 4 decimals and
        the thresholds paid by this call, in table order.
    """
    headroom = max(1.0, 100.0 - baseline_score)
    gained_fraction = (new_score - baseline_score) / headroom

    # Select first, then record: thresholds are distinct dict keys, so checking
    # membership before any of them is added gives the same result.
    newly_earned = [
        level
        for level in MILESTONE_THRESHOLDS
        if gained_fraction >= level and level not in earned
    ]
    earned.update(newly_earned)

    # Plain left-to-right float accumulation, starting from 0.0.
    payout = 0.0
    for level in newly_earned:
        payout += MILESTONE_THRESHOLDS[level]

    return round(payout, 4), newly_earned
77
+
78
+
79
+ # ─────────────────────────────────────────────
80
+ # LOOP DETECTOR
81
+ # ─────────────────────────────────────────────
82
+
83
+ def _detect_loop(previous_actions: list[str], current_action: str) -> bool:
84
+ """Returns True if agent has done the same action 2+ times in a row."""
85
+ if len(previous_actions) < 1:
86
  return False
87
+ last = previous_actions[-1]
88
+ return last == current_action
89
 
90
 
91
  def _count_consecutive(previous_actions: list[str], current_action: str) -> int:
 
92
  count = 1
93
  for a in reversed(previous_actions):
94
  if a == current_action:
 
98
  return count
99
 
100
 
101
+ # ─────────────────────────────────────────────
102
  # EFFICIENCY BONUS
103
+ # ─────────────────────────────────────────────
104
+
105
def _efficiency_bonus(step_count: int, max_steps: int) -> float:
    """Return a bonus when at most 70% of the step budget has been used.

    The bonus is ``EFFICIENCY_BONUS`` scaled by the unused fraction of the
    budget and rounded to 4 decimals; past the 70% mark it is ``0.0``.
    """
    cutoff = max_steps * 0.70
    if not step_count <= cutoff:
        return 0.0

    # max(1, ...) guards the division when max_steps is 0.
    used_fraction = step_count / max(1, max_steps)
    return round(EFFICIENCY_BONUS * (1.0 - used_fraction), 4)
112
 
113
 
114
+ # ─────────────────────────────────────────────
115
  # MAIN REWARD FUNCTION
116
+ # ─────────────────────────────────────────────
117
 
118
  def compute_reward(
119
  action: Action,
 
124
  hints_used: int,
125
  estimated_steps: int,
126
  action_counts: dict[str, int],
127
+ # Round 2 extras (optional β€” backward compatible)
128
+ db_delta: float = 0.0, # Performance score delta from DatabaseSimulator
129
+ baseline_score: float = 0.0, # Scenario baseline score
130
+ current_score: float = 0.0, # Current DB performance score
131
+ milestones_earned: set = None, # Set of already-earned milestone thresholds
132
  ) -> Reward:
133
  """
134
+ Computes dense reward signal for every step.
135
+
136
+ Components:
137
+ 1. Step reward β€” small reward for valid action type
138
+ 2. Delta reward β€” proportional to DB performance improvement (Round 2)
139
+ 3. Milestone bonus β€” one-time bonus at 25%/50%/75% improvement
140
+ 4. Grader score β€” full score on terminal actions (Round 1 compat)
141
+ 5. Loop penalty β€” repeated same action with no improvement
142
+ 6. Hint penalty β€” cost per hint
143
+ 7. Backtrack penalty β€” action made things worse
144
+ 8. Budget penalty β€” approaching max_steps without submitting
145
+ 9. Efficiency bonus β€” solved fast
146
  """
147
 
148
+ if milestones_earned is None:
149
+ milestones_earned = set()
150
+
151
+ breakdown = {}
152
  feedback_parts = []
153
+ final_score = 0.0
154
 
155
  # ── Edge case: null action ────────────────────────────────────
156
  if action is None or action.payload is None:
 
159
  breakdown={"invalid_action": 0.001},
160
  feedback="Invalid or null action received."
161
  )
162
+
163
+ action_type_val = action.action_type.value if hasattr(action.action_type, "value") else str(action.action_type)
164
  action_type_enum = action.action_type
165
 
166
+ # ── 1. Step reward ────────────────────────────────────────────
167
  step_reward = STEP_REWARDS.get(action_type_enum, 0.05)
168
  breakdown["step_reward"] = round(step_reward, 4)
169
  final_score += step_reward
170
  if step_reward > 0:
171
+ feedback_parts.append(f"Action '{action_type_val}' +{step_reward}.")
172
+
173
+ # ── 2. Delta reward (Round 2 DB performance change) ───────────
174
+ if db_delta != 0.0:
175
+ delta_reward = round((db_delta / 100.0) * 0.40, 4)
176
+ delta_reward = max(-0.40, min(0.40, delta_reward))
177
+ breakdown["delta_reward"] = delta_reward
178
+ final_score += delta_reward
179
+ if delta_reward > 0:
180
+ feedback_parts.append(f"DB improved +{db_delta:.1f} pts. Delta reward +{delta_reward}.")
181
+ elif delta_reward < 0:
182
+ feedback_parts.append(f"DB worsened {db_delta:.1f} pts. Penalty {delta_reward}.")
183
+
184
+ # ── 3. Milestone bonuses ──────────────────────────────────────
185
+ if baseline_score > 0 and current_score > 0:
186
+ milestone_bonus, newly_earned = check_milestones(
187
+ baseline_score, current_score, milestones_earned
188
+ )
189
+ if milestone_bonus > 0:
190
+ breakdown["milestone_bonus"] = milestone_bonus
191
+ final_score += milestone_bonus
192
+ pct = int(max(newly_earned) * 100)
193
+ feedback_parts.append(f"🎯 Milestone! {pct}% improvement. Bonus +{milestone_bonus}!")
194
 
195
+ # ── 4. Grader score for terminal actions (Round 1 compat) ─────
196
  grader_score = 0.0
197
+ is_terminal = action_type_enum in TERMINAL_ACTIONS
198
 
199
+ if is_terminal and action_type_enum != ActionType.SUBMIT_REPORT:
200
  raw_score, grader_breakdown, grader_feedback = grade(action, task_id)
201
  grader_score = raw_score
202
+ breakdown["grader_score"] = round(grader_score, 4)
203
  breakdown["grader_breakdown"] = grader_breakdown
204
  final_score += grader_score
205
  feedback_parts.append(grader_feedback)
206
 
 
207
  if grader_score >= 0.5:
208
+ eff_bonus = _efficiency_bonus(step_count, MAX_STEPS)
209
  if eff_bonus > 0:
210
  final_score += eff_bonus
211
  breakdown["efficiency_bonus"] = round(eff_bonus, 4)
212
+ feedback_parts.append(f"Efficiency bonus +{eff_bonus}.")
213
+
214
+ elif is_terminal and action_type_enum == ActionType.SUBMIT_REPORT:
215
+ # Round 2 terminal: compute from DB performance
216
+ if baseline_score > 0 and current_score > 0:
217
+ perf_improvement = (current_score - baseline_score) / max(1.0, 100.0 - baseline_score)
218
+ step_efficiency = 1.0 - (step_count / max(1, MAX_STEPS))
219
+ terminal_score = round(
220
+ (perf_improvement * 0.60) + (step_efficiency * 0.20) + 0.10, 4
221
+ )
222
+ terminal_score = max(0.001, min(0.999, terminal_score))
223
+ breakdown["terminal_score"] = terminal_score
224
+ breakdown["perf_improvement"] = round(perf_improvement, 4)
225
+ breakdown["step_efficiency"] = round(step_efficiency, 4)
226
+ final_score += terminal_score
227
+ feedback_parts.append(
228
+ f"Report submitted. Performance: {baseline_score:.1f} β†’ {current_score:.1f}. "
229
+ f"Terminal score: {terminal_score}."
230
+ )
231
+ # Efficiency bonus on submit_report too
232
+ eff_bonus = _efficiency_bonus(step_count, MAX_STEPS)
233
+ if eff_bonus > 0:
234
+ final_score += eff_bonus
235
+ breakdown["efficiency_bonus"] = round(eff_bonus, 4)
236
+ feedback_parts.append(f"Efficiency bonus +{eff_bonus}.")
237
+ else:
238
+ breakdown["terminal_score"] = 0.10
239
+ final_score += 0.10
240
+ feedback_parts.append("Report submitted.")
241
 
242
  elif action_type_enum == ActionType.PROPOSE_FIX:
 
243
  raw_score, grader_breakdown, _ = grade(action, task_id)
244
+ partial = round(raw_score * 0.4, 4)
 
245
  breakdown["partial_grader_score"] = partial
246
  final_score += partial
 
 
247
 
248
  elif action_type_enum == ActionType.IDENTIFY_ERROR:
 
249
  raw_score, _, _ = grade(action, task_id)
250
+ partial = round(raw_score * 0.2, 4)
251
  breakdown["identification_score"] = partial
252
  final_score += partial
253
 
254
+ # ── 5. Loop penalty ───────────────────────────────────────────
255
  if _detect_loop(previous_actions, action_type_val):
256
  consecutive = _count_consecutive(previous_actions, action_type_val)
257
+ loop_pen = LOOP_PENALTY * min(consecutive - 1, 3)
258
  final_score += loop_pen
259
  breakdown["loop_penalty"] = round(loop_pen, 4)
260
+ feedback_parts.append(f"Loop detected ({consecutive}x). Penalty {loop_pen}.")
261
 
262
+ # ── 6. Hint penalty ───────────────────────────────────────────
263
  if action_type_enum == ActionType.REQUEST_HINT:
264
+ final_score += HINT_PENALTY
265
+ breakdown["hint_penalty"] = HINT_PENALTY
266
+ feedback_parts.append(f"Hint requested. Penalty {HINT_PENALTY}.")
267
+
268
+ # ── 7. Backtrack penalty ──────────────────────────────────────
269
+ if db_delta < -1.0:
270
+ final_score += BACKTRACK_PENALTY
271
+ breakdown["backtrack_penalty"] = BACKTRACK_PENALTY
272
+ feedback_parts.append(f"Performance regressed. Backtrack penalty {BACKTRACK_PENALTY}.")
273
+
274
+ # ── 8. Budget exhaustion penalty ─────────────────────────────
275
+ if step_count >= MAX_STEPS - 2 and not is_terminal:
276
+ final_score += BUDGET_EXHAUSTION_PEN
277
+ breakdown["budget_penalty"] = BUDGET_EXHAUSTION_PEN
278
+ feedback_parts.append("Budget nearly exhausted. Submit report now!")
279
+
280
+ # ── Clamp to (0.001, 0.999) ───────────────────────────────────
281
  final_score = round(max(0.001, min(0.999, final_score)), 4)
282
  breakdown["total"] = final_score
283
 
284
  feedback = " ".join(feedback_parts) if feedback_parts else "Step processed."
285
 
286
+ return Reward(score=final_score, breakdown=breakdown, feedback=feedback)
 
 
 
 
287
 
288
 
289
+ # ─────────────────────────────────────────────
290
  # EPISODE DONE CONDITION
291
+ # ─────────────────────────────────────────────
292
 
293
def is_done(
    action_type: ActionType,
    step_count: int,
    grader_score: float = 0.0,
    target_reached: bool = False,
) -> bool:
    """Decide whether the episode is over.

    The episode ends as soon as any one of these holds, checked in this order:
        1. ``action_type`` is one of ``TERMINAL_ACTIONS``
        2. ``step_count`` has reached ``MAX_STEPS``
        3. ``grader_score`` is at least 1.0
        4. ``target_reached`` is truthy
    """
    return bool(
        action_type in TERMINAL_ACTIONS
        or step_count >= MAX_STEPS
        or grader_score >= 1.0
        or target_reached
    )
env/tasks.py CHANGED
@@ -3,28 +3,50 @@ import random
3
  from pathlib import Path
4
  from env.models import DifficultyLevel, TaskInfo
5
 
6
- # LOAD DATASETS
 
 
7
 
8
  BASE_DIR = Path(__file__).parent.parent / "dataset"
9
 
 
10
def _load(filename: str) -> list[dict]:
    """Parse the UTF-8 JSON file ``filename`` located under ``BASE_DIR``."""
    with (BASE_DIR / filename).open("r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
14
 
 
 
15
  EASY_CASES = _load("easy_cases.json")
16
  MEDIUM_CASES = _load("medium_cases.json")
17
  HARD_CASES = _load("hard_cases.json")
18
 
 
 
 
 
 
 
19
  ALL_CASES: dict[str, list[dict]] = {
20
- DifficultyLevel.EASY: EASY_CASES,
21
- DifficultyLevel.MEDIUM: MEDIUM_CASES,
22
- DifficultyLevel.HARD: HARD_CASES,
23
  }
24
 
 
 
 
 
 
 
 
 
 
25
  # ACTION SCHEMA (required by /tasks validator)
 
26
 
27
  ACTION_SCHEMA = {
 
28
  "identify_error": {
29
  "description": "Identify where and what the error is without fixing it yet",
30
  "payload_fields": {
@@ -36,52 +58,120 @@ ACTION_SCHEMA = {
36
  "propose_fix": {
37
  "description": "Propose a fix without submitting as final answer",
38
  "payload_fields": {
39
- "fixed_query": {"type": "string", "required": True, "description": "The proposed corrected SQL query"},
40
- "change_made": {"type": "string", "required": True, "description": "What specifically was changed"},
41
- "confidence": {"type": "float", "required": False, "description": "Confidence score 0.0-1.0"}
42
  }
43
  },
44
  "submit_answer": {
45
  "description": "Submit the final fixed query as the definitive answer",
46
  "payload_fields": {
47
- "fixed_query": {"type": "string", "required": True, "description": "Final corrected SQL query"},
48
- "explanation": {"type": "string", "required": True, "description": "Full explanation of what was wrong and how it was fixed"},
49
- "error_type": {"type": "string", "required": False, "description": "Type: syntax | logic | performance"},
50
- "confidence": {"type": "float", "required": False, "description": "Confidence score 0.0-1.0"}
51
  }
52
  },
53
  "request_hint": {
54
- "description": "Request a hint β€” costs 0.05 reward penalty per hint",
55
  "payload_fields": {
56
- "hint_type": {"type": "string", "required": False, "description": "Type of hint wanted: location | error_type | fix_direction"}
57
  }
58
  },
59
  "explain_issue": {
60
- "description": "Explain the issue in detail β€” earns partial credit even without fixing",
61
  "payload_fields": {
62
- "explanation": {"type": "string", "required": True, "description": "Detailed explanation of the SQL problem"},
63
- "impact": {"type": "string", "required": False, "description": "What impact the bug has on query results or performance"},
64
- "root_cause": {"type": "string", "required": False, "description": "Root cause analysis"}
65
  }
66
  },
67
  "optimize_query": {
68
- "description": "Submit an optimized version of the query (used for hard/performance tasks)",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  "payload_fields": {
70
- "optimized_query": {"type": "string", "required": True, "description": "The performance-optimized SQL query"},
71
- "optimization_type": {"type": "string", "required": True, "description": "What optimization was applied"},
72
- "expected_improvement":{"type": "string", "required": False, "description": "Expected performance gain description"},
73
- "explanation": {"type": "string", "required": False, "description": "Why this optimization works"},
74
- "confidence": {"type": "float", "required": False, "description": "Confidence 0.0-1.0"}
75
  }
76
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  }
78
- # TASK MANAGER
79
 
80
 
 
 
 
 
81
  class TaskManager:
82
  """
83
- Manages task selection, hint generation, and task metadata.
84
- All tasks are loaded from JSON datasets — no hardcoded tasks.
 
85
  """
86
 
87
  def __init__(self):
@@ -89,9 +179,8 @@ class TaskManager:
89
 
90
  def get_task(self, difficulty: DifficultyLevel, task_id: str | None = None) -> dict:
91
  """
92
- Returns a task dict for the given difficulty.
93
- If task_id is provided, returns that specific task.
94
- Otherwise picks randomly, avoiding recently used tasks.
95
  """
96
  pool = ALL_CASES[difficulty]
97
 
@@ -101,7 +190,7 @@ class TaskManager:
101
  return case
102
  raise ValueError(f"Task '{task_id}' not found in {difficulty} pool")
103
 
104
- # Avoid repeating recently used tasks
105
  available = [c for c in pool if c["id"] not in self._used_ids]
106
  if not available:
107
  self._used_ids.clear()
@@ -112,66 +201,92 @@ class TaskManager:
112
  return task
113
 
114
  def get_random_task(self) -> dict:
115
- """Pick a random task from any difficulty."""
116
  difficulty = random.choice(list(DifficultyLevel))
117
  return self.get_task(difficulty)
118
 
 
 
 
 
 
 
 
 
 
 
119
  def build_observation_context(self, task: dict) -> dict:
120
  """
121
- Build the current_context dict for the Observation.
122
- CRITICAL: Must NOT leak the fixed_query (ground truth) to the agent.
 
123
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  context = {
125
- "buggy_query": task["buggy_query"],
126
- "error_message": task["error_message"],
127
- "database_schema": task["database_schema"],
128
- "error_type_hint": task["error_type"],
129
- "category": task["category"],
130
- "estimated_steps": task["estimated_fix_steps"],
131
  }
132
-
133
- # For performance tasks include extra context
134
  if task.get("performance_issue"):
135
  context["performance_issue"] = {
136
  "type": task["performance_issue"]["type"],
137
  "impact": task["performance_issue"]["impact"],
138
- # Do NOT include timing numbers — agent must figure it out
139
  }
140
-
141
- # Include expected output shape (but not the fixed query!)
142
  if task.get("expected_output") and isinstance(task["expected_output"], list):
143
  context["expected_output_sample"] = task["expected_output"][:1]
144
-
145
  return context
146
 
147
  def get_hint(self, task: dict, hint_number: int) -> str:
148
- """
149
- Returns progressive hints. Each hint gives more info.
150
- Hints cost -0.05 reward each (handled in reward.py).
151
- """
152
- hints = [
153
- f"Hint 1: The error is in the {task.get('error_location', 'query')}.",
154
- f"Hint 2: This is a {task.get('error_type', 'unknown')} type error. Category: {task.get('category')}.",
155
- f"Hint 3: Fix description β€” {task.get('fix_description', 'Review the query carefully.')}",
156
- ]
 
 
 
 
 
 
157
  idx = min(hint_number - 1, len(hints) - 1)
158
- return hints[idx]
159
 
160
  def list_all_tasks(self) -> list[TaskInfo]:
161
- """Returns TaskInfo list for the /tasks endpoint."""
162
  result = []
163
  for difficulty, cases in ALL_CASES.items():
164
  for case in cases:
165
  result.append(TaskInfo(
166
- id=case["id"],
167
- difficulty=difficulty,
168
- description=case["description"],
169
- action_schema=ACTION_SCHEMA
170
  ))
171
  return result
172
 
173
  def get_ground_truth(self, task_id: str) -> dict | None:
174
- """Returns the full ground truth for a task (used by grader only)."""
175
  for cases in ALL_CASES.values():
176
  for case in cases:
177
  if case["id"] == task_id:
@@ -180,4 +295,4 @@ class TaskManager:
180
 
181
 
182
  # Singleton instance
183
- task_manager = TaskManager()
 
3
  from pathlib import Path
4
  from env.models import DifficultyLevel, TaskInfo
5
 
6
+ # ─────────────────────────────────────────────
7
+ # LOAD DATASETS — Round 1 + Round 2
8
+ # ─────────────────────────────────────────────
9
 
10
  BASE_DIR = Path(__file__).parent.parent / "dataset"
11
 
12
+
13
  def _load(filename: str) -> list[dict]:
14
  path = BASE_DIR / filename
15
  with open(path, "r", encoding="utf-8") as f:
16
  return json.load(f)
17
 
18
+
19
+ # Round 1 cases (keep for backward compatibility)
20
  EASY_CASES = _load("easy_cases.json")
21
  MEDIUM_CASES = _load("medium_cases.json")
22
  HARD_CASES = _load("hard_cases.json")
23
 
24
+ # Round 2 scenarios (new long-horizon DB engineering tasks)
25
+ EASY_SCENARIOS = _load("easy_scenarios.json")
26
+ MEDIUM_SCENARIOS = _load("medium_scenarios.json")
27
+ HARD_SCENARIOS = _load("hard_scenarios.json")
28
+
29
+ # Combined pools — Round 2 scenarios take priority (listed first)
30
  ALL_CASES: dict[str, list[dict]] = {
31
+ DifficultyLevel.EASY: EASY_SCENARIOS + EASY_CASES,
32
+ DifficultyLevel.MEDIUM: MEDIUM_SCENARIOS + MEDIUM_CASES,
33
+ DifficultyLevel.HARD: HARD_SCENARIOS + HARD_CASES,
34
  }
35
 
36
+ # Round 2 only (for training pipeline)
37
+ SCENARIO_ONLY: dict[str, list[dict]] = {
38
+ DifficultyLevel.EASY: EASY_SCENARIOS,
39
+ DifficultyLevel.MEDIUM: MEDIUM_SCENARIOS,
40
+ DifficultyLevel.HARD: HARD_SCENARIOS,
41
+ }
42
+
43
+
44
+ # ─────────────────────────────────────────────
45
  # ACTION SCHEMA (required by /tasks validator)
46
+ # ─────────────────────────────────────────────
47
 
48
  ACTION_SCHEMA = {
49
+ # ── Round 1 actions ──────────────────────────────────────────
50
  "identify_error": {
51
  "description": "Identify where and what the error is without fixing it yet",
52
  "payload_fields": {
 
58
  "propose_fix": {
59
  "description": "Propose a fix without submitting as final answer",
60
  "payload_fields": {
61
+ "fixed_query": {"type": "string", "required": True, "description": "The proposed corrected SQL query"},
62
+ "change_made": {"type": "string", "required": True, "description": "What specifically was changed"},
63
+ "confidence": {"type": "float", "required": False, "description": "Confidence score 0.0-1.0"}
64
  }
65
  },
66
  "submit_answer": {
67
  "description": "Submit the final fixed query as the definitive answer",
68
  "payload_fields": {
69
+ "fixed_query": {"type": "string", "required": True, "description": "Final corrected SQL query"},
70
+ "explanation": {"type": "string", "required": True, "description": "Full explanation of fix"},
71
+ "error_type": {"type": "string", "required": False, "description": "syntax | logic | performance"},
72
+ "confidence": {"type": "float", "required": False, "description": "Confidence 0.0-1.0"}
73
  }
74
  },
75
  "request_hint": {
76
+ "description": "Request a hint β€” costs 0.10 reward penalty per hint",
77
  "payload_fields": {
78
+ "hint_type": {"type": "string", "required": False, "description": "location | error_type | fix_direction"}
79
  }
80
  },
81
  "explain_issue": {
82
+ "description": "Explain the issue in detail",
83
  "payload_fields": {
84
+ "explanation": {"type": "string", "required": True, "description": "Detailed explanation"},
85
+ "impact": {"type": "string", "required": False, "description": "Impact on query performance"},
86
+ "root_cause": {"type": "string", "required": False, "description": "Root cause analysis"}
87
  }
88
  },
89
  "optimize_query": {
90
+ "description": "Submit an optimized version of the query",
91
+ "payload_fields": {
92
+ "optimized_query": {"type": "string", "required": True, "description": "Optimized SQL"},
93
+ "optimization_type": {"type": "string", "required": True, "description": "What optimization was applied"},
94
+ "expected_improvement":{"type": "string", "required": False, "description": "Expected performance gain"},
95
+ "explanation": {"type": "string", "required": False, "description": "Why this optimization works"},
96
+ "confidence": {"type": "float", "required": False, "description": "Confidence 0.0-1.0"}
97
+ }
98
+ },
99
+ # ── Round 2 actions ──────────────────────────────────────────
100
+ "inspect_query": {
101
+ "description": "EXPLAIN a slow query β€” reveals scan type, rows examined, index usage",
102
+ "payload_fields": {
103
+ "query_id": {"type": "string", "required": True, "description": "ID of slow query to inspect (e.g. 'q1')"}
104
+ }
105
+ },
106
+ "analyze_indexes": {
107
+ "description": "Show all indexes on a table + usage frequency + missing index hints",
108
  "payload_fields": {
109
+ "table": {"type": "string", "required": True, "description": "Table name to analyze"}
 
 
 
 
110
  }
111
+ },
112
+ "create_index": {
113
+ "description": "Add a composite index on specified columns β€” core optimization action",
114
+ "payload_fields": {
115
+ "table": {"type": "string", "required": True, "description": "Table to index"},
116
+ "columns": {"type": "list|string", "required": True, "description": "Columns to index (list or comma-separated string)"}
117
+ }
118
+ },
119
+ "rewrite_query": {
120
+ "description": "Submit a rewritten SQL query β€” system evaluates execution time improvement",
121
+ "payload_fields": {
122
+ "query_id": {"type": "string", "required": True, "description": "ID of query to rewrite"},
123
+ "new_sql": {"type": "string", "required": True, "description": "Rewritten SQL query"}
124
+ }
125
+ },
126
+ "add_column": {
127
+ "description": "Add a denormalization column to reduce expensive JOINs",
128
+ "payload_fields": {
129
+ "table": {"type": "string", "required": True, "description": "Table to modify"},
130
+ "column": {"type": "string", "required": True, "description": "New column name"},
131
+ "purpose": {"type": "string", "required": False, "description": "Why this column helps"}
132
+ }
133
+ },
134
+ "drop_index": {
135
+ "description": "Remove an unused index to reduce write overhead",
136
+ "payload_fields": {
137
+ "table": {"type": "string", "required": True, "description": "Table name"},
138
+ "index_name": {"type": "string", "required": True, "description": "Index name to drop (cannot drop PRIMARY)"}
139
+ }
140
+ },
141
+ "partition_table": {
142
+ "description": "Partition a large table by date or ID range for range query efficiency",
143
+ "payload_fields": {
144
+ "table": {"type": "string", "required": True, "description": "Table to partition"},
145
+ "partition_by": {"type": "string", "required": False, "description": "Column to partition on (e.g. 'created_at')"},
146
+ "partition_type": {"type": "string", "required": False, "description": "RANGE | LIST | HASH"}
147
+ }
148
+ },
149
+ "analyze_statistics": {
150
+ "description": "Update table statistics for query planner accuracy",
151
+ "payload_fields": {
152
+ "table": {"type": "string", "required": True, "description": "Table to analyze"}
153
+ }
154
+ },
155
+ "submit_report": {
156
+ "description": "TERMINAL: Submit final optimization report β€” ends episode, computes full score",
157
+ "payload_fields": {
158
+ "summary": {"type": "string", "required": True, "description": "Summary of optimizations applied"},
159
+ "actions_taken": {"type": "list", "required": False, "description": "List of key actions taken"},
160
+ "expected_gain": {"type": "string", "required": False, "description": "Expected performance improvement"}
161
+ }
162
+ },
163
  }
 
164
 
165
 
166
+ # ─────────────────────────────────────────────
167
+ # TASK MANAGER
168
+ # ─────────────────────────────────────────────
169
+
170
  class TaskManager:
171
  """
172
+ Manages task selection for both Round 1 and Round 2 scenarios.
173
+ Round 2 scenarios have tables/slow_queries structure.
174
+ Round 1 cases have buggy_query structure.
175
  """
176
 
177
  def __init__(self):
 
179
 
180
  def get_task(self, difficulty: DifficultyLevel, task_id: str | None = None) -> dict:
181
  """
182
+ Returns a task for the given difficulty.
183
+ Prefers Round 2 scenarios, falls back to Round 1 cases.
 
184
  """
185
  pool = ALL_CASES[difficulty]
186
 
 
190
  return case
191
  raise ValueError(f"Task '{task_id}' not found in {difficulty} pool")
192
 
193
+ # Avoid recently used tasks
194
  available = [c for c in pool if c["id"] not in self._used_ids]
195
  if not available:
196
  self._used_ids.clear()
 
201
  return task
202
 
203
  def get_random_task(self) -> dict:
 
204
  difficulty = random.choice(list(DifficultyLevel))
205
  return self.get_task(difficulty)
206
 
207
+ def get_scenario(self, difficulty: DifficultyLevel, scenario_id: str | None = None) -> dict:
208
+ """Get Round 2 scenario specifically."""
209
+ pool = SCENARIO_ONLY[difficulty]
210
+ if scenario_id:
211
+ for s in pool:
212
+ if s["id"] == scenario_id:
213
+ return s
214
+ raise ValueError(f"Scenario '{scenario_id}' not found")
215
+ return random.choice(pool)
216
+
217
  def build_observation_context(self, task: dict) -> dict:
218
  """
219
+ Builds current_context for the Observation.
220
+ Handles both Round 2 scenario format and Round 1 case format.
221
+ CRITICAL: Never leaks ground truth (fixed_query / optimal_actions).
222
  """
223
+ # ── Round 2 scenario format ───────────────────────────────
224
+ if "slow_queries" in task:
225
+ return {
226
+ "scenario_id": task["id"],
227
+ "description": task.get("description", ""),
228
+ "tables": task.get("tables", []),
229
+ "slow_queries": task.get("slow_queries", []),
230
+ "performance_score_baseline": task.get("performance_score_baseline", 0.0),
231
+ "target_score": task.get("target_score", 85.0),
232
+ "max_steps": task.get("max_steps", 50),
233
+ "category": task.get("category", ""),
234
+ # Do NOT include missing_index_hints (that's the answer)
235
+ # Do NOT include optimal_actions (that's the answer)
236
+ }
237
+
238
+ # ── Round 1 case format (backward compatible) ────────────
239
  context = {
240
+ "buggy_query": task.get("buggy_query", ""),
241
+ "error_message": task.get("error_message", ""),
242
+ "database_schema": task.get("database_schema", ""),
243
+ "error_type_hint": task.get("error_type", ""),
244
+ "category": task.get("category", ""),
245
+ "estimated_steps": task.get("estimated_fix_steps", 5),
246
  }
 
 
247
  if task.get("performance_issue"):
248
  context["performance_issue"] = {
249
  "type": task["performance_issue"]["type"],
250
  "impact": task["performance_issue"]["impact"],
 
251
  }
 
 
252
  if task.get("expected_output") and isinstance(task["expected_output"], list):
253
  context["expected_output_sample"] = task["expected_output"][:1]
 
254
  return context
255
 
256
  def get_hint(self, task: dict, hint_number: int) -> str:
257
+ """Progressive hints. Each hint reveals more info. Costs -0.10 each."""
258
+ # Round 2 scenario hints
259
+ if "slow_queries" in task:
260
+ hints = [
261
+ f"Hint 1: Start by inspecting your slow queries with inspect_query action.",
262
+ f"Hint 2: Use analyze_indexes on tables appearing in slow queries.",
263
+ f"Hint 3: Category is '{task.get('category', 'indexing')}'. Target score: {task.get('target_score', 85.0)}.",
264
+ ]
265
+ else:
266
+ # Round 1 hints
267
+ hints = [
268
+ f"Hint 1: The error is in the {task.get('error_location', 'query')}.",
269
+ f"Hint 2: This is a {task.get('error_type', 'unknown')} error. Category: {task.get('category')}.",
270
+ f"Hint 3: Fix: {task.get('fix_description', 'Review the query carefully.')}",
271
+ ]
272
  idx = min(hint_number - 1, len(hints) - 1)
273
+ return hints[max(0, idx)]
274
 
275
  def list_all_tasks(self) -> list[TaskInfo]:
276
+ """Returns TaskInfo list for the /tasks endpoint — all 30 tasks."""
277
  result = []
278
  for difficulty, cases in ALL_CASES.items():
279
  for case in cases:
280
  result.append(TaskInfo(
281
+ id = case["id"],
282
+ difficulty = difficulty,
283
+ description = case.get("description", ""),
284
+ action_schema = ACTION_SCHEMA
285
  ))
286
  return result
287
 
288
  def get_ground_truth(self, task_id: str) -> dict | None:
289
+ """Returns full task including ground truth (used by grader only)."""
290
  for cases in ALL_CASES.values():
291
  for case in cases:
292
  if case["id"] == task_id:
 
295
 
296
 
297
  # Singleton instance
298
+ task_manager = TaskManager()
tests/test_environment.py CHANGED
@@ -22,8 +22,7 @@ def test_reset_easy(env):
22
  assert obs.step_count == 0
23
  assert obs.difficulty == DifficultyLevel.EASY
24
  assert "fixed_query" not in obs.current_context
25
- assert "buggy_query" in obs.current_context
26
-
27
 
28
  def test_reset_medium(env):
29
  obs = env.reset(difficulty="medium")
@@ -65,7 +64,7 @@ def test_step_null_action(env):
65
  """Null action must return -0.1, never crash."""
66
  env.reset(difficulty="easy")
67
  resp = env.step(None)
68
- assert resp.reward.score == -0.1
69
  assert resp.done == False
70
 
71
 
@@ -110,7 +109,7 @@ def test_max_steps(env):
110
  action = Action(action_type=ActionType.IDENTIFY_ERROR,
111
  payload={"error_location": "x", "error_type": "syntax"})
112
  done = False
113
- for _ in range(25):
114
  resp = env.step(action)
115
  if resp.done:
116
  done = True
 
22
  assert obs.step_count == 0
23
  assert obs.difficulty == DifficultyLevel.EASY
24
  assert "fixed_query" not in obs.current_context
25
+ assert "buggy_query" in obs.current_context or "slow_queries" in obs.current_context
 
26
 
27
  def test_reset_medium(env):
28
  obs = env.reset(difficulty="medium")
 
64
  """Null action must return a score >= 0.001 (the clamped minimum), never crash."""
65
  env.reset(difficulty="easy")
66
  resp = env.step(None)
67
+ assert resp.reward.score >= 0.001
68
  assert resp.done == False
69
 
70
 
 
109
  action = Action(action_type=ActionType.IDENTIFY_ERROR,
110
  payload={"error_location": "x", "error_type": "syntax"})
111
  done = False
112
+ for _ in range(55):
113
  resp = env.step(action)
114
  if resp.done:
115
  done = True
tests/test_graders.py CHANGED
@@ -21,7 +21,7 @@ def test_easy_perfect_score():
21
 
22
  def test_null_action_returns_zero():
23
  score, breakdown, feedback = grade(None, "easy_001")
24
- assert score == 0.0
25
  assert "null" in feedback.lower() or "no action" in feedback.lower()
26
 
27
 
@@ -29,7 +29,7 @@ def test_unknown_task_returns_zero():
29
  action = Action(action_type=ActionType.SUBMIT_ANSWER,
30
  payload={"fixed_query": "SELECT 1", "explanation": "test"})
31
  score, _, _ = grade(action, "nonexistent_task_999")
32
- assert score == 0.0
33
 
34
 
35
  def test_determinism():
 
21
 
22
  def test_null_action_returns_zero():
23
  score, breakdown, feedback = grade(None, "easy_001")
24
+ assert score <= 0.001 # clamped minimum for OpenEnv compliance
25
  assert "null" in feedback.lower() or "no action" in feedback.lower()
26
 
27
 
 
29
  action = Action(action_type=ActionType.SUBMIT_ANSWER,
30
  payload={"fixed_query": "SELECT 1", "explanation": "test"})
31
  score, _, _ = grade(action, "nonexistent_task_999")
32
+ assert score <= 0.001
33
 
34
 
35
  def test_determinism():
training/evaluate_agent.py ADDED
File without changes
training/generate_training_data.py ADDED
File without changes
training/train_agent.py ADDED
File without changes