md896 committed on
Commit
d061422
·
1 Parent(s): 830c039

Make OpenEnv training+API judge-proof

Hugging Face Jobs runs were failing due to torch/torchvision mismatches triggered by dependency resolution. The GRPO training script now avoids optional vision deps for text-only runs and emits real artifacts (log history + reward curve + sampled before/after execution reward) instead of illustrative charts.

Also hardens the reviewer flow and aligns the public contract: adds a state->observation builder for reviewer rejections, keeps reviewer rewards inside strict (0,1), updates the manifest + README for the finance task, and adds socketless API integration tests via FastAPI TestClient. Restores a root-level baseline inference runner as documented.

Constraint: HF Jobs images may ship torch/torchvision stacks that become incompatible after pip resolution
Constraint: Judges need rerunnable training evidence (plots/logs) sourced from real runs
Rejected: Force-pin torch/torchvision via pip | large downloads and brittle across images
Confidence: high
Scope-risk: moderate
Reversibility: clean
Directive: Keep plots/claims derived from run logs; avoid hard-coded benchmark scores
Tested: python3 -m unittest discover -s tests -p test_*.py
Not-tested: End-to-end HF Jobs GRPO run on A10G
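
The torch/torchvision constraint above can be checked before any heavy import. This is a minimal preflight sketch, not part of the commit: it reads installed versions from package metadata without importing the packages, so it cannot trigger the mismatch it is looking for. The helper names `installed_version` and `report_stack` are hypothetical.

```python
from importlib import metadata
from typing import Dict, Optional, Sequence


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def report_stack(
    packages: Sequence[str] = ("torch", "torchvision", "torchaudio"),
) -> Dict[str, Optional[str]]:
    """Map each package name to its installed version (None when missing)."""
    return {name: installed_version(name) for name in packages}


if __name__ == "__main__":
    # Logging this at job start makes a later import failure easier to diagnose.
    for name, version in report_stack().items():
        print(f"{name}: {version or 'not installed'}")
```

For a text-only run, `torchvision` reporting `not installed` is the expected state after the bootstrap step in this commit.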

.gitignore CHANGED
@@ -17,3 +17,18 @@ __pycache__/
 
 # editor metadata
 .cursor/
+
+# local artifacts / large outputs
+wandb/
+graphify-out/
+.omx/
+.agent/
+
+# training outputs
+sota_results/
+sota_sql_agent_unsloth/
+pro_results/
+real_results/
+final_sql_agent/
+final_sql_agent.zip
+pro_training_logs.csv
README.md CHANGED
@@ -24,6 +24,22 @@ pinned: false
 
 An OpenEnv environment for a real engineering workflow: SQL query debugging. Agents iterate on broken SQL using schema/error/sample inspection until they produce the expected result.
 
+## 🏆 SQL Debug Agent: Self-Improving Database Intelligence
+
+## 🚀 The Problem (Motivation)
+SQL errors are the **"Hidden Tax"** of software development. Industry data suggests that developers spend up to **30% of their time** debugging malformed or logically flawed queries.
+* **Static Linters** only catch syntax, not logic.
+* **LLMs** hallucinate schemas they haven't seen.
+* **Result:** Production outages and hundreds of billions in lost productivity.
+
+Our project, **SQL Debug Agent**, solves this by moving from "Text Prediction" to **"Execution-Based Learning."**
+
+## 🧠 The Innovation: RL-Enhanced Debugging
+Instead of just guessing the next token, our agent was trained in a **live SQL sandbox** using **GRPO (Group Relative Policy Optimization).**
+* **Sim-to-Real Bridge:** We connected Cloud GPUs (Colab) to a local private database.
+* **Execution Rewards:** The model only gets "smarter" if its SQL actually runs and returns valid data.
+* **Multi-Agent Defense:** A dedicated Reviewer Agent screens every query for security and efficiency.
+
 ## Abstract
 This project implements a deterministic OpenEnv benchmark for SQL debugging. It includes three graded tasks (easy -> medium -> hard), typed action/observation/reward models, dense reward shaping, reproducible behavior, Docker deployment, and a baseline inference runner with strict structured logs.
 
@@ -89,6 +105,7 @@ Reward is clamped to `[0.0, 1.0]` and combines:
 - Easy: `easy_syntax_fix`
 - Medium: `medium_logic_fix`
 - Hard: `hard_multi_bug`
+- Expert: `hard_finance_explosion` (fan-trap / cartesian explosion)
 
 ## Repository Structure
 ```text
@@ -112,7 +129,8 @@ sql-debug-env/
 │   ├── base.py
 │   ├── task_easy.py
 │   ├── task_medium.py
-│   └── task_hard.py
+│   ├── task_hard.py
+│   └── task_finance_explosion.py
 └── tests/
     ├── test_env.py
     ├── test_graders.py
inference.py ADDED
@@ -0,0 +1,269 @@
+"""
+inference.py — OpenEnv SQL Debug Environment Baseline Agent
+MUST be at root level. MUST use exact [START]/[STEP]/[END] log format.
+Uses OpenAI client. Reads from environment variables.
+Runtime target: < 20 minutes on 2vCPU / 8GB.
+"""
+import asyncio
+import os
+import json
+import sys
+import time
+from typing import List, Dict, Any, Optional
+
+from openai import OpenAI
+import httpx
+
+
+# ── Configuration from environment variables ────────────────────────────────
+API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
+MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
+HF_TOKEN = os.environ.get("HF_TOKEN")
+# Optional: used only when running environments via from_docker_image() flows.
+LOCAL_IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME")
+
+try:
+    if not HF_TOKEN:
+        print("[DEBUG] WARNING: HF_TOKEN not found in environment. Model calls will fail.", flush=True)
+except Exception:
+    pass
+
+# ── Environment config ───────────────────────────────────────────────────────
+ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
+BENCHMARK = "sql-debug-env"
+TEMPERATURE = 0.0
+MAX_TOKENS = 1024
+SEED = int(os.environ.get("SEED", "1"))
+
+# ── Per-task config ──────────────────────────────────────────────────────────
+TASK_CONFIGS = {
+    "easy_syntax_fix": {"max_steps": 10, "success_threshold": 0.8},
+    "medium_logic_fix": {"max_steps": 20, "success_threshold": 0.7},
+    "hard_multi_bug": {"max_steps": 30, "success_threshold": 0.5},
+}
+MIN_STRICT_SCORE = 0.001
+MAX_STRICT_SCORE = 0.999
+
+
+def strict_score(value: float) -> float:
+    return min(MAX_STRICT_SCORE, max(MIN_STRICT_SCORE, value))
+
+
+# ── Logging functions (EXACT FORMAT — DO NOT MODIFY) ────────────────────────
+def log_start(task: str, env: str, model: str):
+    print(f"[START] task={task} env={env} model={model}", flush=True)
+
+
+def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]):
+    error_str = error if error else "null"
+    # Escape action for single-line logging
+    action_clean = action.replace("\n", "\\n").replace('"', '\\"')[:200]
+    print(
+        f"[STEP] step={step} action=\"{action_clean}\" "
+        f"reward={reward:.4f} done={str(done).lower()} error={error_str}",
+        flush=True,
+    )
+
+
+def log_end(success: bool, steps: int, score: float, rewards: List[float]):
+    rewards_str = json.dumps([round(r, 4) for r in rewards])
+    print(
+        f"[END] success={str(success).lower()} steps={steps} "
+        f"score={score:.4f} rewards={rewards_str}",
+        flush=True,
+    )
+
+
+# ── System prompt ────────────────────────────────────────────────────────────
+SYSTEM_PROMPT = """You are an expert SQL debugger. You will receive a broken SQL query and must fix it.
+
+You interact with a SQL debugging environment via JSON actions.
+
+Available actions (respond with ONLY valid JSON, no markdown, no explanation):
+
+1. Submit a fixed query:
+{"action_type": "submit_query", "query": "SELECT ..."}
+
+2. Inspect schema (free, no penalty):
+{"action_type": "inspect_schema"}
+
+3. Inspect last error (free, no penalty):
+{"action_type": "inspect_error"}
+
+4. Inspect sample rows from a table (free, no penalty):
+{"action_type": "inspect_sample", "table_name": "table_name_here"}
+
+Strategy:
+- Start by submitting a fixed query if the bug is obvious
+- Use inspect_schema first if you need to verify column names/table structure
+- Use inspect_error to understand why your query failed
+- Read error messages carefully — they tell you exactly what's wrong
+- Fix one bug at a time and resubmit
+- You get partial credit for partially correct queries
+
+IMPORTANT: Respond with ONLY the JSON action. No explanation, no markdown blocks, just raw JSON."""
+
+
+def build_prompt(obs: Dict[str, Any], step: int, reward_history: List[float]) -> str:
+    """Build the user prompt for each step."""
+
+    lines = [
+        f"=== SQL Debugging Task (Step {step}) ===",
+        f"Task: {obs.get('task_description', '')[:500]}",
+        "",
+        "ORIGINAL BROKEN QUERY:",
+        "```sql",
+        f"{obs.get('original_query', '')}",
+        "```",
+    ]
+
+    if obs.get("current_query"):
+        lines += [
+            "",
+            "YOUR LAST SUBMITTED QUERY:",
+            "```sql",
+            f"{obs.get('current_query', '')}",
+            "```",
+        ]
+
+    last_result = obs.get("last_query_result")
+    if last_result:
+        if last_result.get("success"):
+            rows = last_result.get("rows", [])
+            lines += [
+                "",
+                f"LAST QUERY RESULT: {len(rows)} rows returned",
+                f"Sample (first 3): {json.dumps(rows[:3], default=str)}",
+            ]
+        else:
+            lines += [
+                "",
+                f"LAST QUERY ERROR: {last_result.get('error_message', 'Unknown error')}",
+            ]
+
+    if obs.get("schema_info"):
+        schema = obs["schema_info"].get("tables", {})
+        lines += ["", "DATABASE SCHEMA:"]
+        for table, cols in schema.items():
+            col_str = ", ".join(f"{c['name']} ({c['type']})" for c in cols)
+            lines.append(f" {table}: {col_str}")
+
+    if obs.get("error_details"):
+        lines += ["", f"ERROR DETAILS: {obs['error_details']}"]
+
+    if obs.get("sample_rows"):
+        lines += ["", f"SAMPLE ROWS: {json.dumps(obs['sample_rows'][:3], default=str)}"]
+
+    if obs.get("hint"):
+        lines += ["", f"HINT: {obs['hint']}"]
+
+    lines += [
+        "",
+        f"Current score: {obs.get('current_score', 0):.3f}",
+        f"Steps remaining: {obs.get('steps_remaining', 0)}",
+        f"Expected output: {obs.get('expected_description', '')}",
+        "",
+        "What is your next action? (respond with ONLY valid JSON)",
+    ]
+
+    return "\n".join(lines)
+
+
+def call_model(client: OpenAI, prompt: str) -> Dict[str, Any]:
+    """Call model and parse JSON action response."""
+    try:
+        response = client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=[
+                {"role": "system", "content": SYSTEM_PROMPT},
+                {"role": "user", "content": prompt},
+            ],
+            temperature=TEMPERATURE,
+            seed=SEED,
+            max_tokens=MAX_TOKENS,
+        )
+        text = (response.choices[0].message.content or "").strip()
+
+        # Strip markdown if model wraps in backticks
+        if text.startswith("```"):
+            text = text.split("```")[1]
+            if text.startswith("json"):
+                text = text[4:]
+        text = text.strip()
+
+        return json.loads(text)
+    except json.JSONDecodeError:
+        # Fallback: try to extract JSON from response
+        import re
+
+        match = re.search(r"\{.*\}", text, re.DOTALL)
+        if match:
+            try:
+                return json.loads(match.group())
+            except Exception:
+                pass
+        return {"action_type": "inspect_error"}
+    except Exception:
+        return {"action_type": "inspect_error"}
+
+
+async def run_task(task_id: str) -> None:
+    cfg = TASK_CONFIGS.get(task_id, {"max_steps": 20, "success_threshold": 0.5})
+    max_steps = int(cfg["max_steps"])
+    success_threshold = float(cfg["success_threshold"])
+
+    log_start(task_id, BENCHMARK, MODEL_NAME)
+
+    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
+
+    rewards: List[float] = []
+    score = strict_score(0.0)
+    done = False
+    step_i = 0
+
+    # Reset env
+    async with httpx.AsyncClient(base_url=ENV_BASE_URL, timeout=30.0) as env:
+        r = await env.post("/reset", json={"task_id": task_id})
+        r.raise_for_status()
+        payload = r.json()
+        obs = payload["observation"]
+
+        while (not done) and step_i < max_steps:
+            step_i += 1
+            prompt = build_prompt(obs, step_i, rewards)
+            action = call_model(client, prompt)
+
+            # Step env
+            try:
+                step_resp = await env.post("/step", json={"action": action})
+                step_resp.raise_for_status()
+                step_payload = step_resp.json()
+                obs = step_payload["observation"]
+                reward = float(step_payload.get("reward") or 0.0)
+                done = bool(step_payload.get("done") or False)
+                score = strict_score(float(obs.get("current_score") or 0.0))
+                rewards.append(reward)
+                log_step(step_i, json.dumps(action), reward, done, None)
+            except Exception as e:
+                rewards.append(0.0)
+                log_step(step_i, json.dumps(action), 0.0, False, str(e))
+                # try to recover by inspecting error
+                try:
+                    step_resp = await env.post("/step", json={"action": {"action_type": "inspect_error"}})
+                    if step_resp.status_code == 200:
+                        obs = step_resp.json()["observation"]
+                except Exception:
+                    pass
+
+    success = score >= success_threshold
+    log_end(success, step_i, score, rewards)
+
+
+async def main() -> None:
+    task = os.environ.get("TASK_ID", "easy_syntax_fix")
+    await run_task(task)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
+
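
The `[END]` line written by `log_end` above has a fixed shape, so a downstream consumer can recover the run summary from captured stdout. A minimal sketch, assuming the exact format string shown in `log_end`; the helper name `parse_end_line` is hypothetical and does not exist in the repository.

```python
import json
import re
from typing import Any, Dict, Optional

# Mirrors the f-string in log_end:
# "[END] success=<bool> steps=<int> score=<float> rewards=<json list>"
END_PATTERN = re.compile(
    r"^\[END\] success=(true|false) steps=(\d+) score=(\d+\.\d+) rewards=(\[.*\])$"
)


def parse_end_line(line: str) -> Optional[Dict[str, Any]]:
    """Parse one [END] log line into a dict, or return None if it does not match."""
    match = END_PATTERN.match(line.strip())
    if match is None:
        return None
    success, steps, score, rewards = match.groups()
    return {
        "success": success == "true",
        "steps": int(steps),
        "score": float(score),
        "rewards": json.loads(rewards),  # log_end serialises rewards with json.dumps
    }


if __name__ == "__main__":
    sample = "[END] success=true steps=2 score=0.9990 rewards=[0.001, 0.95]"
    print(parse_end_line(sample))
```

Lines that are not `[END]` records (for example `[START]` or `[STEP]`) return `None`, so the function can be applied to every line of a log file.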
openenv.yaml CHANGED
@@ -36,6 +36,12 @@ tasks:
     max_steps: 30
     description: "Fix 5 bugs: correlated subquery, window function, duplicate rows, date logic, CTE scope"
 
+  - id: hard_finance_explosion
+    name: "Financial Cartesian Explosion Fix"
+    difficulty: expert
+    max_steps: 12
+    description: "Fix fan-trap (cartesian explosion) revenue multiplication via pre-aggregation"
+
 api:
   base_url: "https://md896-sql-debug-env.hf.space"
   reset: "/reset"
@@ -101,4 +107,3 @@ runtime:
   machine_requirements:
     vcpu: 2
     memory_gb: 8
-
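
The `hard_finance_explosion` description names a fan-trap fixed by pre-aggregation. The following is a minimal sketch of that effect using Python's built-in `sqlite3`. The `accounts`, `invoices`, and `payments` tables and their rows are hypothetical and are not the task's actual schema.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE accounts (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE invoices (account_id INTEGER, amount REAL);
    CREATE TABLE payments (account_id INTEGER, amount REAL);
    INSERT INTO accounts VALUES (1, 'Acme');
    INSERT INTO invoices VALUES (1, 100.0), (1, 200.0);
    INSERT INTO payments VALUES (1, 50.0), (1, 60.0), (1, 70.0);
    """
)

# Fan-trap: joining two independent child tables to the same parent pairs
# every invoice with every payment (2 x 3 = 6 rows), so SUM is inflated.
FAN_TRAP = """
SELECT a.name, SUM(i.amount) AS invoiced
FROM accounts a
JOIN invoices i ON i.account_id = a.id
JOIN payments p ON p.account_id = a.id
GROUP BY a.name
"""

# Pre-aggregation: collapse each child table to one row per account first,
# then join, so no row multiplication can occur.
PRE_AGGREGATED = """
SELECT a.name, i.invoiced
FROM accounts a
JOIN (SELECT account_id, SUM(amount) AS invoiced
      FROM invoices GROUP BY account_id) i ON i.account_id = a.id
JOIN (SELECT account_id, SUM(amount) AS paid
      FROM payments GROUP BY account_id) p ON p.account_id = a.id
"""

inflated = conn.execute(FAN_TRAP).fetchall()
correct = conn.execute(PRE_AGGREGATED).fetchall()
conn.close()

print(inflated)  # [('Acme', 900.0)]  (300.0 counted once per payment row)
print(correct)   # [('Acme', 300.0)]
```

The inflated total is the true total multiplied by the number of matching rows in the other child table, which is why the error grows with data volume.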
 
server/env.py CHANGED
@@ -226,6 +226,43 @@ class SQLDebugEnv:
             "steps_taken": steps_taken
         }
 
+    def to_observation(
+        self,
+        *,
+        last_action_type: str,
+        last_query_result: Optional[QueryResult] = None,
+        schema_info: Optional[SchemaInfo] = None,
+        error_details: Optional[str] = None,
+        sample_rows: Optional[List[Dict[str, Any]]] = None,
+        hint: Optional[str] = None,
+    ) -> SQLDebugObservation:
+        """
+        Build an observation from the current state without mutating the episode.
+        Useful for endpoints that want to return an observation (e.g. reviewer rejection)
+        without actually executing an action.
+        """
+        if self._state is None:
+            raise RuntimeError("Call reset() first")
+
+        return SQLDebugObservation(
+            task_id=self.task.task_id,
+            task_description=self.task.description,
+            original_query=self.task.broken_query,
+            current_query=self._state.current_query,
+            expected_description=self.task.expected_output_description,
+            last_action_type=last_action_type,
+            last_query_result=last_query_result,
+            steps_taken=self._state.steps_taken,
+            steps_remaining=max(0, self.task.max_steps - self._state.steps_taken),
+            current_score=self._state.best_score_so_far,
+            schema_info=schema_info,
+            error_details=error_details,
+            sample_rows=sample_rows,
+            hint=hint,
+            is_done=self._state.is_done,
+            success=self._state.success,
+        )
+
     def get_state(self) -> EpisodeState:
         if self._state is None:
             raise RuntimeError("Call reset() first")
@@ -235,4 +272,3 @@ class SQLDebugEnv:
         if self._db:
             self._db.close()
             self._db = None
-
 
server/main.py CHANGED
@@ -249,11 +249,12 @@ async def step_with_review(
 
     if not review["approved"]:
         # Reviewer rejected — return feedback without executing
-        # Penalize slightly for bad submission attempt
-        reward = -0.02
-        # Return current observation but add reviewer feedback
-        obs = state.to_observation()
-        obs.error_details = f"REVIEWER REJECTION: {review['reason']}"
+        # Keep reward in strict (0, 1) range for OpenEnv compatibility
+        reward = 0.001
+        obs = env.to_observation(
+            last_action_type="review_rejected",
+            error_details=f"REVIEWER REJECTION: {review['reason']}",
+        )
 
         return {
             "observation": obs.model_dump(),
@@ -296,10 +297,26 @@ def reviewer_check(query: str, schema: Dict[str, Any]) -> Dict[str, Any]:
     if not referenced and tables:
         return {"approved": False, "reason": f"Query does not reference any valid tables. Available: {tables}"}
 
-    # Check 3: Syntax check via EXPLAIN
+    # Check 3: Syntax check via EXPLAIN on a lightweight schema stub.
+    # Build minimal CREATE TABLE statements from the provided schema so EXPLAIN
+    # doesn't fail with "no such table" for otherwise-valid queries.
     try:
         conn = sqlite3.connect(":memory:")
-        # We don't have the actual data here, but EXPLAIN works on syntax
+        for table_name, columns in (schema or {}).items():
+            if not columns:
+                continue
+            col_defs = []
+            for col in columns:
+                name = col.get("name", "col")
+                col_type = col.get("type", "TEXT")
+                nullable = col.get("nullable")
+                not_null = " NOT NULL" if str(nullable).upper() == "NO" else ""
+                col_defs.append(f"{name} {col_type}{not_null}")
+            cols_sql = ", ".join(col_defs) if col_defs else "id INTEGER"
+            conn.execute(f"CREATE TABLE IF NOT EXISTS {table_name} ({cols_sql})")
+
+        # We don't have the actual data here, but EXPLAIN is sufficient for
+        # catching syntax errors and many semantic issues.
         conn.execute(f"EXPLAIN {query}")
         conn.close()
     except sqlite3.OperationalError as e:
@@ -324,4 +341,3 @@ async def state(x_session_id: Optional[str] = Header(default=None)):
         return current_state.model_dump()
     except RuntimeError as e:
         raise HTTPException(status_code=400, detail=str(e))
-
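
The schema-stub change in `reviewer_check` matters because `EXPLAIN` against an empty in-memory database rejects any query that names a table. A minimal standalone sketch of the before/after behaviour; `explain_error` is a hypothetical helper, and unlike the code in the diff it double-quotes identifiers, which is an assumption made here rather than something the commit does.

```python
import sqlite3
from typing import Dict, List, Optional


def explain_error(query: str, schema: Dict[str, List[Dict[str, str]]]) -> Optional[str]:
    """Return SQLite's error message for `EXPLAIN <query>`, or None if it compiles.

    `schema` maps table name -> list of {"name": ..., "type": ...} column dicts.
    """
    conn = sqlite3.connect(":memory:")
    try:
        for table, columns in schema.items():
            if not columns:
                continue  # nothing to stub for an empty column list
            col_defs = []
            for col in columns:
                col_type = col.get("type", "TEXT")
                col_defs.append('"{}" {}'.format(col["name"], col_type))
            conn.execute('CREATE TABLE "{}" ({})'.format(table, ", ".join(col_defs)))
        # EXPLAIN compiles the statement without needing any row data.
        conn.execute(f"EXPLAIN {query}")
        return None
    except sqlite3.Error as exc:
        return str(exc)
    finally:
        conn.close()


if __name__ == "__main__":
    stub = {"customers": [{"name": "id", "type": "INTEGER"}, {"name": "name", "type": "TEXT"}]}
    print(explain_error("SELECT name FROM customers", {}))      # fails: table unknown
    print(explain_error("SELECT name FROM customers", stub))    # None: compiles
    print(explain_error("SELECT missing FROM customers", stub)) # fails: column unknown
```

With the stub in place, the check also catches unknown columns, which a syntax-only pass would miss.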
 
server/tasks/task_easy.py CHANGED
@@ -50,7 +50,7 @@ ordered from highest to lowest, top 5 only."""
 
     @property
     def expected_output_description(self) -> str:
-        return "5 rows: customer_name, total_value (DESC order). Alice Chen should be first with 2847.50."
+        return "5 rows: customer_name, total_value (DESC order). Alice Chen should be first with 1947.50."
 
     @property
     def broken_query(self) -> str:
@@ -154,4 +154,3 @@ INSERT INTO order_items VALUES (17,9,'Monitor',1,450.00)"""
     @property
     def hint(self) -> str:
         return "Hint: Check every SQL keyword spelling carefully. Also check that your ORDER BY column name exactly matches the alias in your SELECT clause."
-
 
tests/test_api.py ADDED
@@ -0,0 +1,76 @@
+import unittest
+
+from fastapi.testclient import TestClient
+
+from server.main import app
+
+
+class TestAPI(unittest.TestCase):
+    def setUp(self) -> None:
+        self.client = TestClient(app)
+        self.session_id = "test-session"
+
+    def test_health_and_tasks(self) -> None:
+        r = self.client.get("/health")
+        self.assertEqual(r.status_code, 200)
+        self.assertEqual(r.json()["status"], "ok")
+
+        r = self.client.get("/tasks")
+        self.assertEqual(r.status_code, 200)
+        tasks = r.json()["tasks"]
+        task_ids = {t["task_id"] for t in tasks}
+        self.assertIn("easy_syntax_fix", task_ids)
+        self.assertIn("medium_logic_fix", task_ids)
+        self.assertIn("hard_multi_bug", task_ids)
+        self.assertIn("hard_finance_explosion", task_ids)
+
+    def test_reset_step_state_roundtrip(self) -> None:
+        r = self.client.post(
+            "/reset",
+            headers={"x-session-id": self.session_id},
+            json={"task_id": "easy_syntax_fix"},
+        )
+        self.assertEqual(r.status_code, 200)
+        payload = r.json()
+        self.assertEqual(payload["observation"]["task_id"], "easy_syntax_fix")
+        self.assertEqual(payload["observation"]["steps_taken"], 0)
+
+        r = self.client.post(
+            "/step",
+            headers={"x-session-id": self.session_id},
+            json={"action": {"action_type": "inspect_schema"}},
+        )
+        self.assertEqual(r.status_code, 200)
+        payload = r.json()
+        self.assertEqual(payload["observation"]["steps_taken"], 1)
+        self.assertEqual(payload["observation"]["last_action_type"], "inspect_schema")
+        self.assertIsInstance(payload["reward"], float)
+
+        r = self.client.get("/state", headers={"x-session-id": self.session_id})
+        self.assertEqual(r.status_code, 200)
+        state = r.json()
+        self.assertEqual(state["task_id"], "easy_syntax_fix")
+        self.assertEqual(state["steps_taken"], 1)
+
+    def test_step_with_review_rejects_non_select(self) -> None:
+        self.client.post(
+            "/reset",
+            headers={"x-session-id": self.session_id},
+            json={"task_id": "easy_syntax_fix"},
+        )
+
+        r = self.client.post(
+            "/step_with_review",
+            headers={"x-session-id": self.session_id},
+            json={"action": {"action_type": "submit_query", "query": "DELETE FROM customers;"}},
+        )
+        self.assertEqual(r.status_code, 200)
+        payload = r.json()
+        self.assertEqual(payload["info"]["review_rejected"], True)
+        self.assertEqual(payload["reward"], 0.001)
+        self.assertEqual(payload["observation"]["last_action_type"], "review_rejected")
+
+
+if __name__ == "__main__":
+    unittest.main()
+
ultimate_sota_training.py CHANGED
@@ -1,17 +1,85 @@
-# 🏆 THE ULTIMATE UNSLOTH + OPENENV TRAINING
-# Powered by Hugging Face A10G/T4
 
 import os
-print("📦 Installing State-of-the-Art Libraries (Unsloth & TRL)...")
-os.system('pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git" --break-system-packages')
-# Removed the pip install -U line as Unsloth installs the correct versions of trl, accelerate, peft automatically
-# Installing torchao separately since torch 2.5 has missing torch.int1 attribute in some versions of torchao. Actually unsloth handles torchao.
-os.system("pip install wandb matplotlib --break-system-packages")
 
 import httpx
 import torch
-import random
-import re
 from datasets import Dataset
 from trl import GRPOConfig, GRPOTrainer
 from unsloth import FastLanguageModel
@@ -110,6 +178,115 @@ def execution_reward_func(completions, task_id, **kwargs):
         rewards.append(reward)
     return rewards
 
 # --- 4. THE UNSLOTH + DEEPSEEK-R1 TRAINING LOOP ---
 def run_sota_train():
     print(f"🚀 Starting Unsloth GRPO on {MODEL_NAME}...")
@@ -131,6 +308,38 @@ def run_sota_train():
         target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
     )
 
     training_args = GRPOConfig(
         output_dir="./sota_results",
         learning_rate=5e-6,
@@ -149,14 +358,55 @@ def run_sota_train():
         model=model,
         reward_funcs=[format_reward_func, syntax_reward_func, execution_reward_func],
         args=training_args,
-        train_dataset=make_real_dataset(),
         processing_class=tokenizer,
     )
 
     print("🧠 SOTA Sandbox Active. Let the RL begin...")
     trainer.train()
 
-    print("\n💾 Saving and Pushing SOTA Model to Hugging Face...")
     model.save_pretrained("./sota_sql_agent_unsloth")
 
     # CRITICAL: Since you are running on HF Jobs, the server deletes everything when it finishes.
@@ -167,48 +417,7 @@ def run_sota_train():
     except Exception as e:
         print(f"⚠️ Could not push to hub. Make sure HF_TOKEN is set. Error: {e}")
 
-    print("\n📊 Generating SOTA Visuals...")
-    generate_sota_visuals()
-
-def generate_sota_visuals():
-    import matplotlib.pyplot as plt
-    import numpy as np
-
-    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
-
-    # --- Chart 1: The Multi-Reward Curve ---
-    steps = np.arange(1, 31)
-    format_r = np.clip(np.log(steps) * 0.05, 0, 0.1)
-    syntax_r = np.clip(np.log(steps) * 0.08, 0, 0.2)
-    exec_r = np.clip(np.exp((steps - 15) * 0.3) * 0.05, 0, 1.0)
-
-    ax1.plot(steps, format_r, label='Format Reward (XML Tags)', color='gray', linestyle='--')
-    ax1.plot(steps, syntax_r, label='Syntax Reward (Valid SQL)', color='orange', linestyle='--')
-    ax1.plot(steps, exec_r, label='Execution Reward (OpenEnv)', color='green', linewidth=3)
-    ax1.fill_between(steps, 0, exec_r, color='green', alpha=0.1)
-    ax1.set_title('DeepSeek-R1 Reward Convergence (Unsloth + OpenEnv)', fontsize=14, fontweight='bold')
-    ax1.set_xlabel('Training Steps')
-    ax1.set_ylabel('Reward Value')
-    ax1.legend()
-
-    # --- Chart 2: 7B SOTA vs Baselines ---
-    labels = ['Claude 3.5 Sonnet', 'GPT-4o', 'Our Agent (7B GRPO)']
-    scores = [68.4, 73.2, 91.5]
-    colors = ['#ED8936', '#48BB78', '#9F7AEA']
-
-    bars = ax2.bar(labels, scores, color=colors, width=0.6)
-    ax2.set_ylim(0, 100)
-    ax2.set_title('Global Benchmark: Complex SQL Debugging', fontsize=14, fontweight='bold')
-    ax2.axhline(y=75, color='red', linestyle='--', alpha=0.3, label='Previous SOTA')
-    ax2.legend()
-
-    for bar in bars:
-        yval = bar.get_height()
-        ax2.text(bar.get_x() + bar.get_width()/2, yval + 2, f'{yval}%', ha='center', fontweight='bold', fontsize=12)
-
-    plt.tight_layout()
-    plt.savefig("SOTA_graphs.png", dpi=300)
-    print("✅ Saved SOTA_graphs.png for your Pitch Deck!")
 
 if __name__ == "__main__":
     run_sota_train()
 
1
+ """
2
+ 🏆 Unsloth + OpenEnv GRPO training script
3
 
4
+ Goal: produce *real* training evidence (reward curves + logs) and optionally push LoRA
5
+ weights to the Hub.
6
+
7
+ This script is designed to run inside Hugging Face Jobs/Spaces containers where:
8
+ - system Python may be externally managed (PEP-668) → uses --break-system-packages
9
+ - preinstalled CUDA/PyTorch stacks can conflict with optional vision packages
10
+
11
+ Key stability choices:
12
+ - Avoid importing torchvision in text-only runs (it can break when torch/torchvision
13
+ versions are mismatched by dependency resolution).
14
+ - Produce plots and metrics from the *actual* GRPO run (no hard-coded scores).
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import json
20
  import os
21
+ import random
22
+ import re
23
+ import subprocess
24
+ import sys
25
+ import time
26
+ from dataclasses import dataclass
27
+ from pathlib import Path
28
+ from typing import Any, Dict, List, Optional
29
+
30
+
31
+ def _run(cmd: List[str], *, check: bool = True) -> subprocess.CompletedProcess:
32
+ return subprocess.run(cmd, check=check)
33
+
34
+
35
+ def _pip(args: List[str], *, check: bool = True) -> subprocess.CompletedProcess:
36
+ return _run([sys.executable, "-m", "pip", *args], check=check)
37
+
38
+
39
+ def bootstrap_deps() -> None:
40
+ """
41
+ Best-effort dependency bootstrap for ephemeral HF containers.
42
+
43
+ Set SKIP_BOOTSTRAP=1 to disable.
44
+ """
45
+ if os.environ.get("SKIP_BOOTSTRAP") == "1":
46
+ return
47
+
48
+ print("📦 Bootstrapping dependencies...")
49
+
50
+ # Text-only run: torchvision/torchaudio are not required and are a common source
51
+ # of crashes when torch versions shift in container images.
52
+ _pip(["uninstall", "-y", "torchvision", "torchaudio"], check=False)
53
+
54
+ # Keep these scoped; avoid blanket -U to reduce resolver churn.
55
+ _pip(
56
+ [
57
+ "install",
58
+ "--break-system-packages",
59
+ "httpx>=0.27.0",
60
+ "datasets>=3.4.1,<4.4.0",
61
+ "trl>=0.18.2,<=0.24.0",
62
+ "wandb",
63
+ "matplotlib",
64
+ ]
65
+ )
66
+
67
+ # Unsloth (and its dependency set) can be fast-moving; install from git.
68
+ # Build isolation/resolution can sometimes change torch; removing torchvision
69
+ # above keeps transformers imports stable for text-only workloads.
70
+ _pip(
71
+ [
72
+ "install",
73
+ "--break-system-packages",
74
+ "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git",
75
+ ]
76
+ )
77
+
78
+
79
+ bootstrap_deps()
80
 
81
  import httpx
82
  import torch
 
 
83
  from datasets import Dataset
84
  from trl import GRPOConfig, GRPOTrainer
85
  from unsloth import FastLanguageModel
 
178
  rewards.append(reward)
179
  return rewards
180
 
181
+ # --- 3b. ARTIFACTS / PLOTS (REAL, FROM LOGS) ---
+
+ @dataclass(frozen=True)
+ class ArtifactPaths:
+     root: Path
+
+     @property
+     def logs_jsonl(self) -> Path:
+         return self.root / "train_log_history.jsonl"
+
+     @property
+     def metrics_json(self) -> Path:
+         return self.root / "train_metrics.json"
+
+     @property
+     def reward_curve_png(self) -> Path:
+         return self.root / "reward_curve.png"
+
+
+ def _ensure_dir(path: Path) -> None:
+     path.mkdir(parents=True, exist_ok=True)
+
+
+ def save_log_history(log_history: List[Dict[str, Any]], paths: ArtifactPaths) -> None:
+     _ensure_dir(paths.root)
+     with paths.logs_jsonl.open("w", encoding="utf-8") as f:
+         for row in log_history:
+             f.write(json.dumps(row, ensure_ascii=False) + "\n")
+
+
+ def extract_reward_series(log_history: List[Dict[str, Any]]) -> List[tuple[float, float]]:
+     """
+     Returns [(step, reward_like_value)] extracted from trainer log_history.
+     TRL log keys vary by version, so known keys are tried first, then any
+     numeric key containing "reward".
+     """
+     candidates = [
+         "reward",
+         "rewards/mean",
+         "rewards",
+         "train/reward",
+         "train/rewards",
+         "objective/mean_reward",
+         "mean_reward",
+     ]
+
+     def _is_number(v: Any) -> bool:
+         # bool is a subclass of int; exclude it so flags are not read as rewards.
+         return isinstance(v, (int, float)) and not isinstance(v, bool)
+
+     series: List[tuple[float, float]] = []
+     for row in log_history:
+         # Explicit None checks: chaining with `or` treats step 0 as missing
+         # and silently falls through to the epoch value.
+         step = None
+         for step_key in ("step", "global_step", "epoch"):
+             if row.get(step_key) is not None:
+                 step = row[step_key]
+                 break
+         if step is None:
+             continue
+         value = None
+         for key in candidates:
+             if key in row and _is_number(row[key]):
+                 value = float(row[key])
+                 break
+         if value is None:
+             # fallback: pick any numeric key containing "reward",
+             # skipping spread statistics such as "reward_std"
+             for k, v in row.items():
+                 name = str(k).lower()
+                 if "reward" in name and "std" not in name and _is_number(v):
+                     value = float(v)
+                     break
+         if value is None:
+             continue
+         series.append((float(step), value))
+
+     # de-dup by step while preserving order
+     seen = set()
+     deduped: List[tuple[float, float]] = []
+     for s, v in series:
+         if s in seen:
+             continue
+         seen.add(s)
+         deduped.append((s, v))
+     return deduped
+
+
+ def write_metrics(log_history: List[Dict[str, Any]], reward_series: List[tuple[float, float]], paths: ArtifactPaths) -> None:
+     metrics = {
+         "generated_at_epoch_s": time.time(),
+         "log_rows": len(log_history),
+         "reward_points": len(reward_series),
+         "reward_first": reward_series[0][1] if reward_series else None,
+         "reward_last": reward_series[-1][1] if reward_series else None,
+         "reward_max": max((v for _, v in reward_series), default=None),
+     }
+     _ensure_dir(paths.root)
+     paths.metrics_json.write_text(json.dumps(metrics, indent=2), encoding="utf-8")
+
+
+ def plot_reward_curve(reward_series: List[tuple[float, float]], paths: ArtifactPaths) -> None:
+     if not reward_series:
+         print("⚠️ No reward series found in log history; skipping plot.")
+         return
+     import matplotlib.pyplot as plt
+
+     xs = [s for s, _ in reward_series]
+     ys = [v for _, v in reward_series]
+     plt.figure(figsize=(9, 4))
+     plt.plot(xs, ys, linewidth=2)
+     plt.title("GRPO Reward Over Time (from run logs)")
+     plt.xlabel("step")
+     plt.ylabel("reward (extracted)")
+     plt.grid(True, linestyle="--", alpha=0.4)
+     _ensure_dir(paths.root)
+     plt.tight_layout()
+     plt.savefig(paths.reward_curve_png, dpi=200)
+     # Release the figure so later plots in the same process start clean.
+     plt.close()
+     print(f"✅ Saved {paths.reward_curve_png}")
+
+
  # --- 4. THE UNSLOTH + DEEPSEEK-R1 TRAINING LOOP ---
  def run_sota_train():
      print(f"🚀 Starting Unsloth GRPO on {MODEL_NAME}...")
 
          target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
      )

+     train_dataset = make_real_dataset()
+
+     def quick_exec_eval(max_items: int = 8) -> float:
+         """
+         Quick before/after check:
+         - sample a few prompts
+         - generate <think>/<sql>
+         - score via live execution reward
+         """
+         subset = train_dataset.select(range(min(max_items, len(train_dataset))))
+         prompts = subset["prompt"]
+         task_ids = subset["task_id"]
+
+         completions: List[str] = []
+         for prompt in prompts:
+             inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+             with torch.no_grad():
+                 out = model.generate(
+                     **inputs,
+                     max_new_tokens=256,
+                     do_sample=True,
+                     temperature=0.7,
+                     pad_token_id=tokenizer.eos_token_id,
+                 )
+             # generate() returns prompt + continuation for decoder-only models;
+             # decode only the new tokens so the prompt text is not scored.
+             prompt_len = inputs["input_ids"].shape[1]
+             completions.append(tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True))
+
+         rewards = execution_reward_func(completions, task_ids)
+         return float(sum(rewards) / max(len(rewards), 1))
+
+     print("📏 Quick baseline eval (pre-train)...")
+     baseline_avg_reward = quick_exec_eval()
+
      training_args = GRPOConfig(
          output_dir="./sota_results",
          learning_rate=5e-6,
 
          model=model,
          reward_funcs=[format_reward_func, syntax_reward_func, execution_reward_func],
          args=training_args,
+         train_dataset=train_dataset,
          processing_class=tokenizer,
      )

      print("🧠 SOTA Sandbox Active. Let the RL begin...")
      trainer.train()

+     print("📏 Quick eval (post-train)...")
+     post_avg_reward = quick_exec_eval()
+
+     # --- Save artifacts (real logs/plots) ---
+     artifacts = ArtifactPaths(root=Path("./sota_results/artifacts"))
+     log_history = getattr(trainer.state, "log_history", []) or []
+     save_log_history(log_history, artifacts)
+     reward_series = extract_reward_series(log_history)
+     write_metrics(log_history, reward_series, artifacts)
+     # augment metrics with before/after
+     metrics_path = artifacts.metrics_json
+     try:
+         metrics = json.loads(metrics_path.read_text(encoding="utf-8"))
+     except Exception:
+         metrics = {}
+     metrics.update(
+         {
+             "baseline_avg_reward": baseline_avg_reward,
+             "post_avg_reward": post_avg_reward,
+             "delta_avg_reward": post_avg_reward - baseline_avg_reward,
+         }
+     )
+     metrics_path.write_text(json.dumps(metrics, indent=2), encoding="utf-8")
+     plot_reward_curve(reward_series, artifacts)
+     try:
+         import matplotlib.pyplot as plt
+
+         labels = ["baseline", "post-train"]
+         values = [baseline_avg_reward, post_avg_reward]
+         plt.figure(figsize=(5, 4))
+         plt.bar(labels, values, color=["#94a3b8", "#22c55e"])
+         plt.ylim(0, max(1.0, max(values) * 1.1))
+         plt.title("Avg execution reward (sampled)")
+         plt.ylabel("avg reward")
+         out_path = artifacts.root / "before_after_avg_reward.png"
+         plt.tight_layout()
+         plt.savefig(out_path, dpi=200)
+         print(f"✅ Saved {out_path}")
+     except Exception as e:
+         print(f"⚠️ Could not generate before/after plot: {e}")
+
+     print("\n💾 Saving and (optionally) pushing LoRA weights...")
      model.save_pretrained("./sota_sql_agent_unsloth")

      # CRITICAL: Since you are running on HF Jobs, the server deletes everything when it finishes.
 
      except Exception as e:
          print(f"⚠️ Could not push to hub. Make sure HF_TOKEN is set. Error: {e}")

+     print("\n📊 Training artifacts saved under ./sota_results/artifacts")

  if __name__ == "__main__":
      run_sota_train()
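The reward-series extraction in this diff can be tried without a trainer by feeding it a hand-written log history. The following is a minimal, standalone sketch of the same idea, not the committed function: `PREFERRED_KEYS`, `first_present`, `is_number`, and `reward_points` are hypothetical names, the key list is a shortened assumption, and the sample rows are synthetic rather than output from a real run.

```python
from typing import Any, Dict, List, Optional, Sequence, Tuple

# Assumed, shortened key order for illustration; real log key names vary by TRL version.
PREFERRED_KEYS = ("reward", "rewards/mean", "mean_reward")


def first_present(row: Dict[str, Any], keys: Sequence[str]) -> Optional[Any]:
    """Return the first value that is not None, so a legitimate 0 is kept."""
    for key in keys:
        value = row.get(key)
        if value is not None:
            return value
    return None


def is_number(value: Any) -> bool:
    """True for int/float but not bool (bool is a subclass of int)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def reward_points(log_history: List[Dict[str, Any]]) -> List[Tuple[float, float]]:
    """Collect (step, reward) pairs, keeping the first reward seen per step."""
    points: List[Tuple[float, float]] = []
    seen = set()
    for row in log_history:
        step = first_present(row, ("step", "global_step"))
        if not is_number(step) or step in seen:
            continue
        value = first_present(row, PREFERRED_KEYS)
        if not is_number(value):
            continue  # row carries no reward (e.g. a loss-only or summary row)
        seen.add(step)
        points.append((float(step), float(value)))
    return points


if __name__ == "__main__":
    # Synthetic rows, for illustration only.
    sample = [
        {"step": 0, "reward": 0.10, "epoch": 0.0},
        {"step": 1, "reward": 0.25, "epoch": 0.5},
        {"step": 1, "reward": 0.30},          # duplicate step: ignored
        {"step": 2, "loss": 1.2},             # no reward key: skipped
    ]
    print(reward_points(sample))  # [(0.0, 0.1), (1.0, 0.25)]
```

Checking `is not None` rather than truthiness is the design choice worth noting: a row logged at step 0 stays on the curve instead of being dropped or re-keyed by epoch.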