junaid0600 commited on
Commit
028dbb9
Β·
verified Β·
1 Parent(s): 431b9a5

Update env/graders.py

Browse files
Files changed (1) hide show
  1. env/graders.py +356 -110
env/graders.py CHANGED
@@ -1,38 +1,34 @@
1
  import re
 
 
 
2
  from env.models import Action, DifficultyLevel
3
  from env.tasks import task_manager
4
 
5
  # ─────────────────────────────────────────────
6
- # HELPERS
7
  # ─────────────────────────────────────────────
8
 
9
  def _normalize(text: str) -> str:
10
- """Normalize SQL for comparison β€” lowercase, strip whitespace, collapse spaces."""
11
  if not isinstance(text, str):
12
  return ""
13
  return re.sub(r"\s+", " ", text.strip().lower())
14
 
15
  def _safe_get(payload: dict, key: str, default=None):
16
- """Safe dict access β€” never KeyError."""
17
  if not isinstance(payload, dict):
18
  return default
19
  return payload.get(key, default)
20
 
21
  def _score_explanation(explanation: str) -> float:
22
- """Score explanation quality by length and keyword richness."""
23
  if not explanation or not isinstance(explanation, str):
24
  return 0.0
25
  explanation = explanation.strip()
26
- if len(explanation) < 10:
27
- return 0.0
28
- if len(explanation) < 30:
29
- return 0.05
30
- if len(explanation) < 80:
31
- return 0.10
32
  return 0.15
33
 
34
  def _score_confidence(confidence) -> float:
35
- """Give partial credit for providing a valid confidence score."""
36
  try:
37
  c = float(confidence)
38
  if 0.0 <= c <= 1.0:
@@ -42,32 +38,21 @@ def _score_confidence(confidence) -> float:
42
  return 0.0
43
 
44
  def _query_similarity(submitted: str, expected: str) -> float:
45
- """
46
- Multi-level SQL similarity check.
47
- Returns 0.0 - 1.0 based on how close the submitted query is to expected.
48
- """
49
  s = _normalize(submitted)
50
  e = _normalize(expected)
51
-
52
  if s == e:
53
  return 1.0
54
-
55
  s_tokens = set(s.split())
56
  e_tokens = set(e.split())
57
-
58
  if not e_tokens:
59
  return 0.0
60
-
61
  overlap = len(s_tokens & e_tokens) / len(e_tokens)
62
-
63
  critical_keywords = _extract_critical_keywords(e)
64
  critical_found = sum(1 for kw in critical_keywords if kw in s)
65
  critical_score = critical_found / len(critical_keywords) if critical_keywords else 0.0
66
-
67
  return round((overlap * 0.4) + (critical_score * 0.6), 4)
68
 
69
  def _extract_critical_keywords(query: str) -> list[str]:
70
- """Extract SQL keywords that are critical to correctness."""
71
  keywords = [
72
  "left join", "inner join", "right join",
73
  "group by", "order by", "having",
@@ -84,7 +69,6 @@ def _extract_critical_keywords(query: str) -> list[str]:
84
  return found
85
 
86
  def _score_error_type(submitted_type: str, expected_type: str) -> float:
87
- """Score for correctly identifying the error type."""
88
  if not submitted_type:
89
  return 0.0
90
  s = submitted_type.strip().lower()
@@ -102,7 +86,6 @@ def _score_error_type(submitted_type: str, expected_type: str) -> float:
102
  return 0.0
103
 
104
  def _score_error_location(submitted_location: str, expected_location: str) -> float:
105
- """Score for correctly identifying WHERE in the query the error is."""
106
  if not submitted_location or not expected_location:
107
  return 0.0
108
  s = submitted_location.strip().lower()
@@ -116,120 +99,395 @@ def _score_error_location(submitted_location: str, expected_location: str) -> fl
116
 
117
 
118
  # ─────────────────────────────────────────────
119
- # GRADERS PER DIFFICULTY
120
  # ─────────────────────────────────────────────
121
 
122
- def grade_easy(action: Action, ground_truth: dict) -> tuple[float, dict, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  """
124
- Easy task grader β€” syntax errors.
125
- Max score: 1.0
126
- DETERMINISTIC: same input always returns same score.
 
 
 
 
 
 
 
 
127
  """
128
  if action is None or action.payload is None:
129
- return 0.0, {"error": "null_action"}, "No action provided."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  payload = action.payload
132
  score = 0.0
133
  breakdown = {}
134
  feedback_parts = []
135
 
136
- # ── 1. Query fix correctness (0.50) ──────────────────────────
137
  submitted_query = _safe_get(payload, "fixed_query", "") or _safe_get(payload, "optimized_query", "")
138
  expected_query = ground_truth.get("fixed_query", "")
139
  similarity = _query_similarity(submitted_query, expected_query)
140
 
141
  if similarity >= 1.0:
142
- fix_score = 0.50
143
- feedback_parts.append("Correct fix applied.")
144
  elif similarity >= 0.75:
145
- fix_score = 0.30
146
- feedback_parts.append("Fix is mostly correct but has minor differences.")
147
  elif similarity >= 0.50:
148
- fix_score = 0.15
149
- feedback_parts.append("Fix is partially correct.")
150
  else:
151
- fix_score = 0.0
152
- feedback_parts.append("Fix is incorrect or not provided.")
153
 
154
  score += fix_score
155
  breakdown["fix_correctness"] = round(fix_score, 4)
156
 
157
- # ── 2. Error location (0.15) ─────────────────────────────────
158
  submitted_location = _safe_get(payload, "error_location", "")
159
  expected_location = ground_truth.get("error_location", "")
160
  loc_score = _score_error_location(str(submitted_location), expected_location)
161
  score += loc_score
162
  breakdown["error_location"] = round(loc_score, 4)
163
- if loc_score > 0:
164
- feedback_parts.append("Correctly identified error location.")
165
 
166
- # ── 3. Error type (0.10) ─────────────────────────────────────
167
  submitted_type = _safe_get(payload, "error_type", "")
168
  expected_type = ground_truth.get("error_type", "syntax")
169
  type_score = _score_error_type(str(submitted_type), expected_type)
170
  score += type_score
171
  breakdown["error_type"] = round(type_score, 4)
172
- if type_score > 0:
173
- feedback_parts.append("Correctly identified error type.")
174
 
175
- # ── 4. Explanation quality (0.15) ────────────────────────────
176
  explanation = _safe_get(payload, "explanation", "") or _safe_get(payload, "change_made", "")
177
  expl_score = _score_explanation(str(explanation) if explanation else "")
178
  score += expl_score
179
  breakdown["explanation"] = round(expl_score, 4)
180
- if expl_score > 0:
181
- feedback_parts.append("Explanation provided.")
182
 
183
- # ── 5. Confidence (0.05) ─────────────────────────────────────
184
  confidence = _safe_get(payload, "confidence", None)
185
  conf_score = _score_confidence(confidence)
186
  score += conf_score
187
  breakdown["confidence"] = round(conf_score, 4)
188
 
189
- final_score = round(max(0.0, min(1.0, score)), 4)
190
  feedback = " ".join(feedback_parts) if feedback_parts else "No valid response provided."
191
  return final_score, breakdown, feedback
192
 
193
 
194
  def grade_medium(action: Action, ground_truth: dict) -> tuple[float, dict, str]:
195
- """
196
- Medium task grader β€” logic errors.
197
- Max score: 1.0
198
- DETERMINISTIC: same input always returns same score.
199
- """
200
  if action is None or action.payload is None:
201
- return 0.0, {"error": "null_action"}, "No action provided."
202
 
203
  payload = action.payload
204
  score = 0.0
205
  breakdown = {}
206
  feedback_parts = []
207
 
208
- # ── 1. Query fix correctness (0.40) ──────────────────────────
209
  submitted_query = _safe_get(payload, "fixed_query", "") or _safe_get(payload, "optimized_query", "")
210
  expected_query = ground_truth.get("fixed_query", "")
211
  similarity = _query_similarity(submitted_query, expected_query)
212
 
213
  if similarity >= 1.0:
214
- fix_score = 0.40
215
- feedback_parts.append("Correct fix applied.")
216
  elif similarity >= 0.80:
217
- fix_score = 0.28
218
- feedback_parts.append("Fix is mostly correct.")
219
  elif similarity >= 0.60:
220
- fix_score = 0.16
221
- feedback_parts.append("Fix is partially correct.")
222
  elif similarity >= 0.40:
223
- fix_score = 0.08
224
- feedback_parts.append("Fix shows some understanding.")
225
  else:
226
- fix_score = 0.0
227
- feedback_parts.append("Fix is incorrect or missing.")
228
 
229
  score += fix_score
230
  breakdown["fix_correctness"] = round(fix_score, 4)
231
 
232
- # ── 2. Logic flaw identification (0.20) ──────────────────────
233
  explanation = str(_safe_get(payload, "explanation", "") or _safe_get(payload, "change_made", "") or "")
234
  error_type = ground_truth.get("error_type", "logic")
235
 
@@ -238,35 +496,29 @@ def grade_medium(action: Action, ground_truth: dict) -> tuple[float, dict, str]:
238
  "aggregate", "subquery", "correlation", "distinct", "count"],
239
  "performance": ["index", "scan", "n+1", "correlated", "cartesian", "window"]
240
  }
241
-
242
  keywords_to_check = logic_keywords.get(error_type, logic_keywords["logic"])
243
  expl_lower = explanation.lower()
244
  keyword_hits = sum(1 for kw in keywords_to_check if kw in expl_lower)
245
  logic_score = min(keyword_hits * 0.05, 0.20)
246
  score += logic_score
247
  breakdown["logic_flaw_identification"] = round(logic_score, 4)
248
- if logic_score > 0:
249
- feedback_parts.append("Shows understanding of the logic flaw.")
250
 
251
- # ── 3. Error location (0.15) ─────────────────────────────────
252
  submitted_location = _safe_get(payload, "error_location", "")
253
  expected_location = ground_truth.get("error_location", "")
254
  loc_score = _score_error_location(str(submitted_location), expected_location)
255
  score += loc_score
256
  breakdown["error_location"] = round(loc_score, 4)
257
 
258
- # ── 4. Explanation quality (0.15) ────────────────────────────
259
  expl_score = _score_explanation(explanation)
260
  score += expl_score
261
  breakdown["explanation"] = round(expl_score, 4)
262
 
263
- # ── 5. Confidence (0.05) ─────────────────────────────────────
264
  confidence = _safe_get(payload, "confidence", None)
265
  conf_score = _score_confidence(confidence)
266
  score += conf_score
267
  breakdown["confidence"] = round(conf_score, 4)
268
 
269
- # ── 6. Impact analysis bonus (0.05) ──────────────────────────
270
  impact = str(_safe_get(payload, "impact", "") or "")
271
  if len(impact.strip()) > 20:
272
  score += 0.05
@@ -275,27 +527,20 @@ def grade_medium(action: Action, ground_truth: dict) -> tuple[float, dict, str]:
275
  else:
276
  breakdown["impact_analysis"] = 0.0
277
 
278
- final_score = round(max(0.0, min(1.0, score)), 4)
279
  feedback = " ".join(feedback_parts) if feedback_parts else "No valid response provided."
280
  return final_score, breakdown, feedback
281
 
282
 
283
  def grade_hard(action: Action, ground_truth: dict) -> tuple[float, dict, str]:
284
- """
285
- Hard task grader β€” performance issues.
286
- Max score: 1.0
287
- Frontier models expected ~0.10-0.20.
288
- DETERMINISTIC: same input always returns same score.
289
- """
290
  if action is None or action.payload is None:
291
- return 0.0, {"error": "null_action"}, "No action provided."
292
 
293
  payload = action.payload
294
  score = 0.0
295
  breakdown = {}
296
  feedback_parts = []
297
 
298
- # ── 1. Query correctness (0.30) ──────────────────────────────
299
  submitted_query = (
300
  _safe_get(payload, "optimized_query", "")
301
  or _safe_get(payload, "fixed_query", "")
@@ -305,25 +550,19 @@ def grade_hard(action: Action, ground_truth: dict) -> tuple[float, dict, str]:
305
  similarity = _query_similarity(submitted_query, expected_query)
306
 
307
  if similarity >= 1.0:
308
- fix_score = 0.30
309
- feedback_parts.append("Perfectly optimized query.")
310
  elif similarity >= 0.85:
311
- fix_score = 0.22
312
- feedback_parts.append("Query is mostly correct.")
313
  elif similarity >= 0.65:
314
- fix_score = 0.14
315
- feedback_parts.append("Query shows correct approach but incomplete.")
316
  elif similarity >= 0.40:
317
- fix_score = 0.07
318
- feedback_parts.append("Query partially addresses the issue.")
319
  else:
320
- fix_score = 0.0
321
- feedback_parts.append("Query does not address the performance issue.")
322
 
323
  score += fix_score
324
  breakdown["query_correctness"] = round(fix_score, 4)
325
 
326
- # ── 2. Performance concept identification (0.30) ──────────────
327
  explanation = str(_safe_get(payload, "explanation", "") or _safe_get(payload, "change_made", "") or "")
328
  optimization = str(_safe_get(payload, "optimization_type", "") or "")
329
  combined_text = (explanation + " " + optimization).lower()
@@ -347,17 +586,14 @@ def grade_hard(action: Action, ground_truth: dict) -> tuple[float, dict, str]:
347
 
348
  score += concept_score
349
  breakdown["performance_concept"] = round(concept_score, 4)
350
- if concept_score > 0:
351
- feedback_parts.append("Demonstrates understanding of the performance issue.")
352
 
353
- # ── 3. Explanation depth (0.15) ───────────────────────────────
354
  expl_score = _score_explanation(explanation)
355
  if len(explanation.strip()) > 150:
356
  expl_score = min(expl_score + 0.05, 0.15)
357
  score += expl_score
358
  breakdown["explanation_depth"] = round(expl_score, 4)
359
 
360
- # ── 4. Root cause analysis (0.10) ─────────────────────────────
361
  root_cause = str(_safe_get(payload, "root_cause", "") or "")
362
  if len(root_cause.strip()) > 30:
363
  score += 0.10
@@ -366,7 +602,6 @@ def grade_hard(action: Action, ground_truth: dict) -> tuple[float, dict, str]:
366
  else:
367
  breakdown["root_cause_analysis"] = 0.0
368
 
369
- # ── 5. Expected improvement (0.10) ────────────────────────────
370
  improvement = str(_safe_get(payload, "expected_improvement", "") or "")
371
  if len(improvement.strip()) > 20:
372
  score += 0.10
@@ -375,13 +610,12 @@ def grade_hard(action: Action, ground_truth: dict) -> tuple[float, dict, str]:
375
  else:
376
  breakdown["expected_improvement"] = 0.0
377
 
378
- # ── 6. Confidence (0.05) ──────────────────────────────────────
379
  confidence = _safe_get(payload, "confidence", None)
380
  conf_score = _score_confidence(confidence)
381
  score += conf_score
382
  breakdown["confidence"] = round(conf_score, 4)
383
 
384
- final_score = round(max(0.0, min(1.0, score)), 4)
385
  feedback = " ".join(feedback_parts) if feedback_parts else "Performance issue not identified."
386
  return final_score, breakdown, feedback
387
 
@@ -393,18 +627,30 @@ def grade_hard(action: Action, ground_truth: dict) -> tuple[float, dict, str]:
393
  def grade(action: Action, task_id: str) -> tuple[float, dict, str]:
394
  """
395
  Main grader entry point.
396
- Looks up ground truth, dispatches to correct grader by difficulty.
397
- ALWAYS returns (float, dict, str) β€” never crashes.
398
- Score is always between 0.0 and 1.0.
 
 
 
 
 
 
 
399
  """
400
  if action is None:
401
- return 0.0, {"error": "null_action"}, "No action provided."
 
 
 
 
402
 
 
403
  ground_truth = task_manager.get_ground_truth(task_id)
404
  if ground_truth is None:
405
- return 0.0, {"error": "unknown_task"}, f"Task '{task_id}' not found."
406
 
407
- difficulty = ground_truth.get("id", "").split("_")[0]
408
 
409
  try:
410
  if difficulty == "easy":
@@ -414,6 +660,6 @@ def grade(action: Action, task_id: str) -> tuple[float, dict, str]:
414
  elif difficulty == "hard":
415
  return grade_hard(action, ground_truth)
416
  else:
417
- return 0.0, {"error": "unknown_difficulty"}, f"Unknown difficulty: {difficulty}"
418
  except Exception as e:
419
- return 0.0, {"error": str(e)}, f"Grader error: {str(e)}"
 
1
  import re
2
+ import json
3
+ import os
4
+ from functools import lru_cache
5
  from env.models import Action, DifficultyLevel
6
  from env.tasks import task_manager
7
 
8
  # ─────────────────────────────────────────────
9
+ # HELPERS (unchanged from Round 1)
10
  # ─────────────────────────────────────────────
11
 
12
  def _normalize(text: str) -> str:
 
13
  if not isinstance(text, str):
14
  return ""
15
  return re.sub(r"\s+", " ", text.strip().lower())
16
 
17
  def _safe_get(payload: dict, key: str, default=None):
 
18
  if not isinstance(payload, dict):
19
  return default
20
  return payload.get(key, default)
21
 
22
  def _score_explanation(explanation: str) -> float:
 
23
  if not explanation or not isinstance(explanation, str):
24
  return 0.0
25
  explanation = explanation.strip()
26
+ if len(explanation) < 10: return 0.0
27
+ if len(explanation) < 30: return 0.05
28
+ if len(explanation) < 80: return 0.10
 
 
 
29
  return 0.15
30
 
31
  def _score_confidence(confidence) -> float:
 
32
  try:
33
  c = float(confidence)
34
  if 0.0 <= c <= 1.0:
 
38
  return 0.0
39
 
40
  def _query_similarity(submitted: str, expected: str) -> float:
 
 
 
 
41
  s = _normalize(submitted)
42
  e = _normalize(expected)
 
43
  if s == e:
44
  return 1.0
 
45
  s_tokens = set(s.split())
46
  e_tokens = set(e.split())
 
47
  if not e_tokens:
48
  return 0.0
 
49
  overlap = len(s_tokens & e_tokens) / len(e_tokens)
 
50
  critical_keywords = _extract_critical_keywords(e)
51
  critical_found = sum(1 for kw in critical_keywords if kw in s)
52
  critical_score = critical_found / len(critical_keywords) if critical_keywords else 0.0
 
53
  return round((overlap * 0.4) + (critical_score * 0.6), 4)
54
 
55
  def _extract_critical_keywords(query: str) -> list[str]:
 
56
  keywords = [
57
  "left join", "inner join", "right join",
58
  "group by", "order by", "having",
 
69
  return found
70
 
71
  def _score_error_type(submitted_type: str, expected_type: str) -> float:
 
72
  if not submitted_type:
73
  return 0.0
74
  s = submitted_type.strip().lower()
 
86
  return 0.0
87
 
88
  def _score_error_location(submitted_location: str, expected_location: str) -> float:
 
89
  if not submitted_location or not expected_location:
90
  return 0.0
91
  s = submitted_location.strip().lower()
 
99
 
100
 
101
  # ─────────────────────────────────────────────
102
+ # ROUND 2 β€” SCENARIO LOADER
103
  # ─────────────────────────────────────────────
104
 
105
+ # Cache for loaded scenarios β€” avoids re-reading JSON on every grader call
106
+ _scenario_cache: dict[str, dict] = {}
107
+ _cache_loaded = False
108
+
109
+ def _load_all_scenarios():
110
+ """Load all Round 2 scenario JSONs into cache once at startup."""
111
+ global _cache_loaded
112
+ if _cache_loaded:
113
+ return
114
+ for fname in [
115
+ "dataset/easy_scenarios.json",
116
+ "dataset/medium_scenarios.json",
117
+ "dataset/hard_scenarios.json",
118
+ ]:
119
+ try:
120
+ with open(fname) as f:
121
+ for s in json.load(f):
122
+ _scenario_cache[s["id"]] = s
123
+ except FileNotFoundError:
124
+ pass
125
+ except Exception:
126
+ pass
127
+ _cache_loaded = True
128
+
129
+ def _get_scenario(task_id: str) -> dict | None:
130
+ """Get a Round 2 scenario by ID. Returns None if not found."""
131
+ _load_all_scenarios()
132
+ return _scenario_cache.get(task_id)
133
+
134
+ def _is_scenario_task(task_id: str) -> bool:
135
+ """
136
+ Round 2 scenario IDs have format: easy_s001, medium_s002, hard_s003.
137
+ Round 1 task IDs have format: easy_001, medium_001, hard_001.
138
+ Distinction: Round 2 has 's' before the number.
139
+ """
140
+ if not task_id:
141
+ return False
142
+ parts = task_id.split("_")
143
+ # easy_s001 β†’ ["easy", "s001"] | easy_001 β†’ ["easy", "001"]
144
+ return len(parts) >= 2 and parts[-1].startswith("s")
145
+
146
+
147
+ # ─────────────────────────────────────────────
148
+ # ROUND 2 β€” DB ACTION GRADER
149
+ # ─────────────────────────────────────────────
150
+
151
+ def grade_db_action(action: Action, task_id: str) -> tuple[float, dict, str]:
152
  """
153
+ Grades a Round 2 database engineering action.
154
+
155
+ Scoring philosophy:
156
+ - Does the action target valid tables/queries in THIS scenario?
157
+ - For create_index: does it match the missing_index_hints?
158
+ - For rewrite_query: is the SQL structurally better?
159
+ - For submit_report: was a meaningful summary provided?
160
+ - All terminal/non-terminal actions get meaningful differentiation.
161
+
162
+ Returns (score 0.001-0.999, breakdown dict, feedback string).
163
+ DETERMINISTIC: same input β†’ same score always.
164
  """
165
  if action is None or action.payload is None:
166
+ return 0.001, {"error": "null_action"}, "No action provided."
167
+
168
+ scenario = _get_scenario(task_id)
169
+ if scenario is None:
170
+ # Unknown scenario β€” give a small score for valid action structure
171
+ return 0.10, {"error": "scenario_not_found"}, f"Scenario '{task_id}' not in dataset."
172
+
173
+ action_type = (
174
+ action.action_type.value
175
+ if hasattr(action.action_type, "value")
176
+ else str(action.action_type)
177
+ )
178
+ payload = action.payload or {}
179
+
180
+ valid_tables = {t["name"] for t in scenario.get("tables", [])}
181
+ valid_queries = {q["id"] for q in scenario.get("slow_queries", [])}
182
+ hints = scenario.get("missing_index_hints", [])
183
+ large_tables = {
184
+ t["name"] for t in scenario.get("tables", [])
185
+ if t.get("rows", 0) > 100_000
186
+ }
187
+
188
+ score = 0.0
189
+ breakdown = {}
190
+ feedback = []
191
+
192
+ # ── inspect_query ─────────────────────────────────────────────
193
+ if action_type == "inspect_query":
194
+ qid = str(payload.get("query_id", "")).strip()
195
+ if qid in valid_queries:
196
+ score = 0.40
197
+ feedback.append(f"Inspecting valid slow query '{qid}'.")
198
+ breakdown["query_valid"] = 0.40
199
+ elif qid:
200
+ score = 0.10
201
+ feedback.append(f"Query '{qid}' not in scenario slow_queries.")
202
+ breakdown["query_valid"] = 0.10
203
+ else:
204
+ score = 0.05
205
+ feedback.append("No query_id provided in payload.")
206
+ breakdown["query_valid"] = 0.05
207
+
208
+ # ── analyze_indexes ───────────────────────────────────────────
209
+ elif action_type == "analyze_indexes":
210
+ table = str(payload.get("table", "")).strip()
211
+ if table in valid_tables:
212
+ score = 0.35
213
+ feedback.append(f"Analyzing indexes on valid table '{table}'.")
214
+ breakdown["table_valid"] = 0.35
215
+ elif table:
216
+ score = 0.08
217
+ feedback.append(f"Table '{table}' not in scenario.")
218
+ breakdown["table_valid"] = 0.08
219
+ else:
220
+ score = 0.05
221
+ feedback.append("No table provided in payload.")
222
+ breakdown["table_valid"] = 0.05
223
+
224
+ # ── create_index ──────────────────────────────────────────────
225
+ elif action_type == "create_index":
226
+ table = str(payload.get("table", "")).strip()
227
+ cols = payload.get("columns", [])
228
+
229
+ # Normalise columns: accept list or comma-string
230
+ if isinstance(cols, str):
231
+ cols = [c.strip() for c in cols.split(",") if c.strip()]
232
+ elif not isinstance(cols, list):
233
+ cols = []
234
+
235
+ if table not in valid_tables:
236
+ score = 0.05
237
+ feedback.append(f"Table '{table}' not in scenario.")
238
+ breakdown["table_valid"] = 0.05
239
+ elif not cols:
240
+ score = 0.10
241
+ feedback.append("Table valid but no columns specified.")
242
+ breakdown["columns_valid"] = 0.10
243
+ else:
244
+ # Score against missing_index_hints
245
+ best_match = 0.0
246
+ for hint in hints:
247
+ if hint.get("table") == table:
248
+ hint_cols = set(hint.get("columns", []))
249
+ submitted_cols = set(cols)
250
+ if hint_cols and submitted_cols:
251
+ overlap = len(hint_cols & submitted_cols) / len(hint_cols)
252
+ best_match = max(best_match, overlap)
253
+
254
+ if best_match >= 1.0:
255
+ score = 0.85
256
+ feedback.append(
257
+ f"Perfect index on {table}({', '.join(cols)}) β€” "
258
+ "matches missing_index_hints exactly."
259
+ )
260
+ breakdown["index_match"] = 0.85
261
+ elif best_match >= 0.5:
262
+ score = 0.55
263
+ feedback.append(
264
+ f"Partial index match on {table} ({int(best_match*100)}% column overlap)."
265
+ )
266
+ breakdown["index_match"] = 0.55
267
+ elif hints:
268
+ # Table valid, hints exist but columns don't match
269
+ score = 0.20
270
+ feedback.append(
271
+ f"Table '{table}' is valid but columns {cols} don't match any hint."
272
+ )
273
+ breakdown["index_match"] = 0.20
274
+ else:
275
+ # No hints in scenario β€” any reasonable index gets credit
276
+ score = 0.35
277
+ feedback.append(f"Index on {table}({', '.join(cols)}) β€” no hints to verify against.")
278
+ breakdown["index_match"] = 0.35
279
+
280
+ # ── rewrite_query ─────────────────────────────────────────────
281
+ elif action_type == "rewrite_query":
282
+ qid = str(payload.get("query_id", "")).strip()
283
+ new_sql = str(payload.get("new_sql", "")).strip()
284
+
285
+ base = 0.0
286
+ if qid in valid_queries:
287
+ base = 0.20
288
+ feedback.append(f"Rewriting valid query '{qid}'.")
289
+ elif qid:
290
+ base = 0.05
291
+ feedback.append(f"Query '{qid}' not in scenario.")
292
+ else:
293
+ base = 0.03
294
+ feedback.append("No query_id provided.")
295
+
296
+ sql_bonus = 0.0
297
+ if new_sql and len(new_sql) > 15:
298
+ lower = new_sql.lower()
299
+ if "select *" not in lower: sql_bonus += 0.10
300
+ if "join" in lower and "where" in lower: sql_bonus += 0.10
301
+ if "index" in lower or "force index" in lower: sql_bonus += 0.08
302
+ if "left join" in lower or "inner join" in lower: sql_bonus += 0.05
303
+ feedback.append("SQL provided and has structure.")
304
+ else:
305
+ feedback.append("No new_sql provided.")
306
+
307
+ score = min(base + sql_bonus, 0.65)
308
+ breakdown["rewrite_quality"] = round(score, 4)
309
+
310
+ # ── partition_table ───────────────────────────────────────────
311
+ elif action_type == "partition_table":
312
+ table = str(payload.get("table", "")).strip()
313
+ col = str(payload.get("partition_column", "")).strip()
314
+
315
+ if table in large_tables:
316
+ score = 0.65
317
+ feedback.append(f"Correct β€” '{table}' is large and benefits from partitioning.")
318
+ breakdown["partition_benefit"] = 0.65
319
+ if col:
320
+ score = min(score + 0.10, 0.75)
321
+ feedback.append(f"Partition column '{col}' specified.")
322
+ elif table in valid_tables:
323
+ score = 0.20
324
+ feedback.append(f"Table '{table}' exists but may not need partitioning (check row count).")
325
+ breakdown["partition_benefit"] = 0.20
326
+ else:
327
+ score = 0.05
328
+ feedback.append(f"Table '{table}' not in scenario.")
329
+ breakdown["partition_benefit"] = 0.05
330
+
331
+ # ── analyze_statistics ────────────────────────────────────────
332
+ elif action_type == "analyze_statistics":
333
+ table = str(payload.get("table", "")).strip()
334
+ if table in valid_tables:
335
+ score = 0.30
336
+ feedback.append(f"Analyzing statistics on valid table '{table}'.")
337
+ breakdown["table_valid"] = 0.30
338
+ else:
339
+ score = 0.08
340
+ feedback.append(f"Table '{table}' not in scenario.")
341
+ breakdown["table_valid"] = 0.08
342
+
343
+ # ── drop_index ────────────────────────────────────────────────
344
+ elif action_type == "drop_index":
345
+ table = str(payload.get("table", "")).strip()
346
+ idx = str(payload.get("index_name", "")).strip()
347
+ if table in valid_tables and idx and idx != "PRIMARY":
348
+ score = 0.25
349
+ feedback.append(f"Dropping index '{idx}' on '{table}'.")
350
+ elif idx == "PRIMARY":
351
+ score = 0.001
352
+ feedback.append("Cannot drop PRIMARY index.")
353
+ else:
354
+ score = 0.05
355
+ feedback.append("Invalid table or index_name.")
356
+ breakdown["drop_validity"] = score
357
+
358
+ # ── add_column ─────────────────────────────────────���──────────
359
+ elif action_type == "add_column":
360
+ table = str(payload.get("table", "")).strip()
361
+ col = str(payload.get("column_name", "")).strip()
362
+ if table in valid_tables and col:
363
+ score = 0.25
364
+ feedback.append(f"Adding column '{col}' to '{table}'.")
365
+ else:
366
+ score = 0.05
367
+ feedback.append("Missing table or column_name.")
368
+ breakdown["add_column"] = score
369
+
370
+ # ── request_hint ──────────────────────────────────────────────
371
+ elif action_type == "request_hint":
372
+ # Hint requests are penalised in the environment reward but still valid actions
373
+ score = 0.10
374
+ feedback.append("Hint requested β€” valid but penalised in full episode reward.")
375
+ breakdown["hint_penalty_note"] = 0.10
376
+
377
+ # ── submit_report ─────────────────────────────────────────────
378
+ elif action_type == "submit_report":
379
+ summary = str(payload.get("summary", "")).strip()
380
+ # Score on summary quality β€” episode score handled separately by /grader
381
+ if len(summary) >= 100:
382
+ score = 0.50
383
+ feedback.append("Detailed report submitted.")
384
+ elif len(summary) >= 30:
385
+ score = 0.30
386
+ feedback.append("Brief report submitted.")
387
+ elif summary:
388
+ score = 0.15
389
+ feedback.append("Minimal report submitted.")
390
+ else:
391
+ score = 0.05
392
+ feedback.append("Empty report β€” include a summary of actions taken.")
393
+ breakdown["report_quality"] = score
394
+
395
+ # ── unknown action ────────────────────────────────────────────
396
+ else:
397
+ score = 0.05
398
+ feedback.append(f"Unknown action_type '{action_type}'.")
399
+ breakdown["unknown_action"] = 0.05
400
+
401
+ final_score = round(max(0.001, min(0.999, score)), 4)
402
+ return final_score, breakdown, " ".join(feedback) or "Action processed."
403
+
404
+
405
+ # ─────────────────────────────────────────────
406
+ # ROUND 1 GRADERS (unchanged)
407
+ # ─────────────────────────────────────────────
408
+
409
+ def grade_easy(action: Action, ground_truth: dict) -> tuple[float, dict, str]:
410
+ if action is None or action.payload is None:
411
+ return 0.001, {"error": "null_action"}, "No action provided."
412
 
413
  payload = action.payload
414
  score = 0.0
415
  breakdown = {}
416
  feedback_parts = []
417
 
 
418
  submitted_query = _safe_get(payload, "fixed_query", "") or _safe_get(payload, "optimized_query", "")
419
  expected_query = ground_truth.get("fixed_query", "")
420
  similarity = _query_similarity(submitted_query, expected_query)
421
 
422
  if similarity >= 1.0:
423
+ fix_score = 0.50; feedback_parts.append("Correct fix applied.")
 
424
  elif similarity >= 0.75:
425
+ fix_score = 0.30; feedback_parts.append("Fix is mostly correct but has minor differences.")
 
426
  elif similarity >= 0.50:
427
+ fix_score = 0.15; feedback_parts.append("Fix is partially correct.")
 
428
  else:
429
+ fix_score = 0.0; feedback_parts.append("Fix is incorrect or not provided.")
 
430
 
431
  score += fix_score
432
  breakdown["fix_correctness"] = round(fix_score, 4)
433
 
 
434
  submitted_location = _safe_get(payload, "error_location", "")
435
  expected_location = ground_truth.get("error_location", "")
436
  loc_score = _score_error_location(str(submitted_location), expected_location)
437
  score += loc_score
438
  breakdown["error_location"] = round(loc_score, 4)
439
+ if loc_score > 0: feedback_parts.append("Correctly identified error location.")
 
440
 
 
441
  submitted_type = _safe_get(payload, "error_type", "")
442
  expected_type = ground_truth.get("error_type", "syntax")
443
  type_score = _score_error_type(str(submitted_type), expected_type)
444
  score += type_score
445
  breakdown["error_type"] = round(type_score, 4)
446
+ if type_score > 0: feedback_parts.append("Correctly identified error type.")
 
447
 
 
448
  explanation = _safe_get(payload, "explanation", "") or _safe_get(payload, "change_made", "")
449
  expl_score = _score_explanation(str(explanation) if explanation else "")
450
  score += expl_score
451
  breakdown["explanation"] = round(expl_score, 4)
452
+ if expl_score > 0: feedback_parts.append("Explanation provided.")
 
453
 
 
454
  confidence = _safe_get(payload, "confidence", None)
455
  conf_score = _score_confidence(confidence)
456
  score += conf_score
457
  breakdown["confidence"] = round(conf_score, 4)
458
 
459
+ final_score = round(max(0.001, min(0.999, score)), 4)
460
  feedback = " ".join(feedback_parts) if feedback_parts else "No valid response provided."
461
  return final_score, breakdown, feedback
462
 
463
 
464
  def grade_medium(action: Action, ground_truth: dict) -> tuple[float, dict, str]:
 
 
 
 
 
465
  if action is None or action.payload is None:
466
+ return 0.001, {"error": "null_action"}, "No action provided."
467
 
468
  payload = action.payload
469
  score = 0.0
470
  breakdown = {}
471
  feedback_parts = []
472
 
 
473
  submitted_query = _safe_get(payload, "fixed_query", "") or _safe_get(payload, "optimized_query", "")
474
  expected_query = ground_truth.get("fixed_query", "")
475
  similarity = _query_similarity(submitted_query, expected_query)
476
 
477
  if similarity >= 1.0:
478
+ fix_score = 0.40; feedback_parts.append("Correct fix applied.")
 
479
  elif similarity >= 0.80:
480
+ fix_score = 0.28; feedback_parts.append("Fix is mostly correct.")
 
481
  elif similarity >= 0.60:
482
+ fix_score = 0.16; feedback_parts.append("Fix is partially correct.")
 
483
  elif similarity >= 0.40:
484
+ fix_score = 0.08; feedback_parts.append("Fix shows some understanding.")
 
485
  else:
486
+ fix_score = 0.0; feedback_parts.append("Fix is incorrect or missing.")
 
487
 
488
  score += fix_score
489
  breakdown["fix_correctness"] = round(fix_score, 4)
490
 
 
491
  explanation = str(_safe_get(payload, "explanation", "") or _safe_get(payload, "change_made", "") or "")
492
  error_type = ground_truth.get("error_type", "logic")
493
 
 
496
  "aggregate", "subquery", "correlation", "distinct", "count"],
497
  "performance": ["index", "scan", "n+1", "correlated", "cartesian", "window"]
498
  }
 
499
  keywords_to_check = logic_keywords.get(error_type, logic_keywords["logic"])
500
  expl_lower = explanation.lower()
501
  keyword_hits = sum(1 for kw in keywords_to_check if kw in expl_lower)
502
  logic_score = min(keyword_hits * 0.05, 0.20)
503
  score += logic_score
504
  breakdown["logic_flaw_identification"] = round(logic_score, 4)
505
+ if logic_score > 0: feedback_parts.append("Shows understanding of the logic flaw.")
 
506
 
 
507
  submitted_location = _safe_get(payload, "error_location", "")
508
  expected_location = ground_truth.get("error_location", "")
509
  loc_score = _score_error_location(str(submitted_location), expected_location)
510
  score += loc_score
511
  breakdown["error_location"] = round(loc_score, 4)
512
 
 
513
  expl_score = _score_explanation(explanation)
514
  score += expl_score
515
  breakdown["explanation"] = round(expl_score, 4)
516
 
 
517
  confidence = _safe_get(payload, "confidence", None)
518
  conf_score = _score_confidence(confidence)
519
  score += conf_score
520
  breakdown["confidence"] = round(conf_score, 4)
521
 
 
522
  impact = str(_safe_get(payload, "impact", "") or "")
523
  if len(impact.strip()) > 20:
524
  score += 0.05
 
527
  else:
528
  breakdown["impact_analysis"] = 0.0
529
 
530
+ final_score = round(max(0.001, min(0.999, score)), 4)
531
  feedback = " ".join(feedback_parts) if feedback_parts else "No valid response provided."
532
  return final_score, breakdown, feedback
533
 
534
 
535
  def grade_hard(action: Action, ground_truth: dict) -> tuple[float, dict, str]:
 
 
 
 
 
 
536
  if action is None or action.payload is None:
537
+ return 0.001, {"error": "null_action"}, "No action provided."
538
 
539
  payload = action.payload
540
  score = 0.0
541
  breakdown = {}
542
  feedback_parts = []
543
 
 
544
  submitted_query = (
545
  _safe_get(payload, "optimized_query", "")
546
  or _safe_get(payload, "fixed_query", "")
 
550
  similarity = _query_similarity(submitted_query, expected_query)
551
 
552
  if similarity >= 1.0:
553
+ fix_score = 0.30; feedback_parts.append("Perfectly optimized query.")
 
554
  elif similarity >= 0.85:
555
+ fix_score = 0.22; feedback_parts.append("Query is mostly correct.")
 
556
  elif similarity >= 0.65:
557
+ fix_score = 0.14; feedback_parts.append("Query shows correct approach but incomplete.")
 
558
  elif similarity >= 0.40:
559
+ fix_score = 0.07; feedback_parts.append("Query partially addresses the issue.")
 
560
  else:
561
+ fix_score = 0.0; feedback_parts.append("Query does not address the performance issue.")
 
562
 
563
  score += fix_score
564
  breakdown["query_correctness"] = round(fix_score, 4)
565
 
 
566
  explanation = str(_safe_get(payload, "explanation", "") or _safe_get(payload, "change_made", "") or "")
567
  optimization = str(_safe_get(payload, "optimization_type", "") or "")
568
  combined_text = (explanation + " " + optimization).lower()
 
586
 
587
  score += concept_score
588
  breakdown["performance_concept"] = round(concept_score, 4)
589
+ if concept_score > 0: feedback_parts.append("Demonstrates understanding of the performance issue.")
 
590
 
 
591
  expl_score = _score_explanation(explanation)
592
  if len(explanation.strip()) > 150:
593
  expl_score = min(expl_score + 0.05, 0.15)
594
  score += expl_score
595
  breakdown["explanation_depth"] = round(expl_score, 4)
596
 
 
597
  root_cause = str(_safe_get(payload, "root_cause", "") or "")
598
  if len(root_cause.strip()) > 30:
599
  score += 0.10
 
602
  else:
603
  breakdown["root_cause_analysis"] = 0.0
604
 
 
605
  improvement = str(_safe_get(payload, "expected_improvement", "") or "")
606
  if len(improvement.strip()) > 20:
607
  score += 0.10
 
610
  else:
611
  breakdown["expected_improvement"] = 0.0
612
 
 
613
  confidence = _safe_get(payload, "confidence", None)
614
  conf_score = _score_confidence(confidence)
615
  score += conf_score
616
  breakdown["confidence"] = round(conf_score, 4)
617
 
618
+ final_score = round(max(0.001, min(0.999, score)), 4)
619
  feedback = " ".join(feedback_parts) if feedback_parts else "Performance issue not identified."
620
  return final_score, breakdown, feedback
621
 
 
627
  def grade(action: Action, task_id: str) -> tuple[float, dict, str]:
628
  """
629
  Main grader entry point.
630
+
631
+ ROUTING:
632
+ Round 2 scenario IDs (easy_s001, medium_s002, hard_s003)
633
+ β†’ grade_db_action() ← NEW: scores DB engineering actions
634
+
635
+ Round 1 task IDs (easy_001, medium_001, hard_001)
636
+ β†’ grade_easy/medium/hard() ← unchanged
637
+
638
+ ALWAYS returns (float, dict, str). NEVER crashes.
639
+ Score always strictly between 0.001 and 0.999.
640
  """
641
  if action is None:
642
+ return 0.001, {"error": "null_action"}, "No action provided."
643
+
644
+ # ── Round 2: DB engineering scenario ─────────────────────────
645
+ if _is_scenario_task(task_id):
646
+ return grade_db_action(action, task_id)
647
 
648
+ # ── Round 1: SQL debugging task ───────────────────────────────
649
  ground_truth = task_manager.get_ground_truth(task_id)
650
  if ground_truth is None:
651
+ return 0.001, {"error": "unknown_task"}, f"Task '{task_id}' not found."
652
 
653
+ difficulty = task_id.split("_")[0]
654
 
655
  try:
656
  if difficulty == "easy":
 
660
  elif difficulty == "hard":
661
  return grade_hard(action, ground_truth)
662
  else:
663
+ return 0.001, {"error": "unknown_difficulty"}, f"Unknown difficulty: {difficulty}"
664
  except Exception as e:
665
+ return 0.001, {"error": str(e)}, f"Grader error: {str(e)}"