junaid0600 committed on
Commit
f5f1b7a
·
1 Parent(s): ea504bf
Files changed (1) hide show
  1. inference.py +239 -57
inference.py CHANGED
@@ -1,91 +1,273 @@
 
 
 
 
 
 
1
  import os
2
- from dotenv import load_dotenv
3
- load_dotenv()
 
4
 
5
  from openai import OpenAI
6
 
7
- # ── Environment variables ──────────────────────────────────────────
 
 
 
 
 
 
8
# Base URL of the OpenAI-compatible endpoint; overridable through the environment.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
# Model identifier sent with every chat-completion request.
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# Token used as the client's API key below; it has no default.
HF_TOKEN = os.getenv("HF_TOKEN")

# Fail at import time when the token is missing or empty.
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is required")

# Module-level client used by call_llm().
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
# Value printed in the env= field of the [START] records.
BENCHMARK = "sql-query-debugger"
17
 
18
- # ── MONKEY-PATCH must happen BEFORE importing baseline ────────────
19
- # The grader reads reward.score from env.step() directly.
20
- # We wrap step() so reward.score is always strictly in (0, 1).
21
- from env.environment import SQLDebuggerEnvironment
22
 
23
# Keep a handle on the unpatched method so the wrapper can delegate to it.
_original_step = SQLDebuggerEnvironment.step


def _patched_step(self, action):
    """Call the original ``step`` and clamp ``result.reward.score``.

    When the result exposes ``reward.score``, the value is converted to
    float, forced into the range [0.001, 0.999] and rounded to 4 decimal
    places in place. Results without that attribute are returned untouched.
    """
    result = _original_step(self, action)
    if hasattr(result, "reward") and hasattr(result.reward, "score"):
        raw = float(result.reward.score)
        result.reward.score = round(max(0.001, min(0.999, raw)), 4)
    return result


# Install the wrapper on the class itself, so it applies to every instance.
SQLDebuggerEnvironment.step = _patched_step
print("[DEBUG] SQLDebuggerEnvironment.step patched successfully", flush=True)
 
34
 
35
- # ── NOW safe to import baseline ───────────────────────────────────
36
- from baseline import run_baseline
37
 
38
- # ── Logging helpers ───────────────────────────────────────────────
39
def log_start(task, env, model):
    """Print the ``[START]`` record for one task and flush stdout."""
    record = f"[START] task={task} env={env} model={model}"
    print(record, flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
def log_step(step, action, reward, done, error=None):
    """Print one ``[STEP]`` record and flush stdout.

    ``reward`` is rendered with four decimals, ``done`` as a lowercase
    boolean, and a falsy ``error`` as the literal ``null``.
    """
    done_text = str(done).lower()
    error_text = error or 'null'
    print(
        f"[STEP] step={step} action={action} reward={reward:.4f} "
        f"done={done_text} error={error_text}",
        flush=True,
    )
48
 
49
def log_end(success, steps, rewards):
    """Print the ``[END]`` record with comma-separated four-decimal rewards."""
    formatted = [f"{value:.4f}" for value in rewards]
    success_text = str(success).lower()
    print(
        f"[END] success={success_text} steps={steps} rewards={','.join(formatted)}",
        flush=True,
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- # ── Mandatory LLM call ────────────────────────────────────────────
54
def call_llm(prompt: str) -> str:
    """Send *prompt* as a single user message and return the stripped reply.

    Any exception raised while calling the API or reading the response is
    reported on stdout as a ``[DEBUG]`` line and an empty string is returned.
    """
    try:
        request = {
            "model": MODEL_NAME,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.3,
            "max_tokens": 100,
        }
        completion = client.chat.completions.create(**request)
        content = completion.choices[0].message.content
        # The API may return no content; treat that as an empty reply.
        return (content or "").strip()
    except Exception as e:
        print(f"[DEBUG] LLM call failed: {e}", flush=True)
        return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- # ── Main ──────────────────────────────────────────────────────────
68
def main():
    """Make one LLM call, run the baseline, and log every task result.

    Each baseline result is clamped to [0.001, 0.999], rounded to four
    decimals and reported as a one-step episode through the
    [START]/[STEP]/[END] helpers.
    """
    print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", flush=True)
    print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", flush=True)

    llm_response = call_llm("Fix this SQL query: SELECT id name FROM users WHERE")
    preview = llm_response[:80]
    print(f"[DEBUG] LLM response: {preview}", flush=True)

    baseline = run_baseline()

    clamped_scores = []
    for item in baseline.results:
        clamped = round(max(0.001, min(0.999, float(item.score))), 4)
        clamped_scores.append(clamped)

        log_start(task=item.task_id, env=BENCHMARK, model=MODEL_NAME)
        log_step(step=1, action="submit_answer", reward=clamped, done=True)
        log_end(success=clamped > 0.5, steps=1, rewards=[clamped])
        print(f"[DEBUG] task={item.task_id} final_score={clamped}", flush=True)

    # With no results the reported average is 0.5.
    if clamped_scores:
        avg = sum(clamped_scores) / len(clamped_scores)
    else:
        avg = 0.5
    print(f"\n[DEBUG] Average Score: {avg:.4f}", flush=True)


if __name__ == "__main__":
    main()
 
1
+ """
2
+ inference.py — SQL Query Debugger OpenEnv
3
+ Follows the mandatory [START]/[STEP]/[END] stdout format.
4
+ Uses OpenAI client with API_BASE_URL, MODEL_NAME, HF_TOKEN.
5
+ """
6
+
7
  import os
8
+ import json
9
+ import textwrap
10
+ from typing import List, Optional
11
 
12
  from openai import OpenAI
13
 
14
+ from env.environment import SQLDebuggerEnvironment
15
+ from env.models import Action, ActionType, DifficultyLevel
16
+
17
+ # ─────────────────────────────────────────────
18
+ # ENVIRONMENT VARIABLES
19
+ # ─────────────────────────────────────────────
20
# API key for the client: the first non-empty value among HF_TOKEN, API_KEY
# and OPENAI_API_KEY, otherwise a placeholder string.
# NOTE(review): the placeholder presumably lets the client be constructed
# without credentials; requests would still be rejected by the endpoint.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY") or "dummy-key"
# Base URL of the OpenAI-compatible endpoint; overridable through the environment.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
# Model identifier sent with every chat-completion request.
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# Value printed in the env= field of the [START] records.
BENCHMARK = "sql-query-debugger"
# Upper bound on steps per episode; also the divisor used to normalise the score.
MAX_STEPS = 10
# Minimum normalised score for an episode to be reported as success=true.
SUCCESS_SCORE_THRESHOLD = 0.5
26
 
27
+ # ─────────────────────────────────────────────
28
+ # LOGGING FUNCTIONS — exact format required
29
+ # ─────────────────────────────────────────────
30
 
31
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens an episode and flush stdout."""
    record = f"[START] task={task} env={env} model={model}"
    print(record, flush=True)
33
 
 
 
 
 
34
 
35
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str] = None) -> None:
    """Print one ``[STEP]`` record on a single line and flush stdout.

    Args:
        step: 1-based step number within the episode.
        action: Name of the action that was taken.
        reward: Reward for this step, rendered with two decimals.
        done: Whether the episode ended on this step (rendered lowercase).
        error: Optional error text; a falsy value is rendered as ``null``.
    """
    # Exception messages can span several lines; collapsing whitespace runs
    # keeps the record on one line so line-based consumers still parse it.
    error_val = " ".join(str(error).split()) if error else ""
    if not error_val:
        error_val = "null"
    done_val = str(done).lower()
    print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True)
39
 
 
 
 
 
 
 
40
 
41
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` record that closes an episode and flush stdout.

    ``score`` is rendered with three decimals and ``rewards`` as a
    comma-separated list of two-decimal values.
    """
    formatted = [f"{value:.2f}" for value in rewards]
    success_text = str(success).lower()
    summary = (
        f"[END] success={success_text} steps={steps} "
        f"score={score:.3f} rewards={','.join(formatted)}"
    )
    print(summary, flush=True)
44
 
 
 
45
 
46
+ # ─────────────────────────────────────────────
47
+ # SYSTEM PROMPT
48
+ # ─────────────────────────────────────────────
49
+
50
# System message sent with every request. It restricts the model to one of
# two JSON reply shapes, which get_llm_action() then parses.
SYSTEM_PROMPT = textwrap.dedent("""
    You are an expert SQL debugger. You will be given a buggy SQL query and must fix it.

    You must respond with a JSON object only — no explanation outside the JSON.

    For syntax/logic errors, respond with:
    {
    "action_type": "submit_answer",
    "fixed_query": "<your fixed SQL query here>",
    "explanation": "<brief explanation of what was wrong>",
    "error_type": "<syntax|logic|performance>",
    "error_location": "<where in the query the error is>",
    "confidence": 0.9
    }

    For performance issues, respond with:
    {
    "action_type": "optimize_query",
    "optimized_query": "<your optimized SQL query here>",
    "optimization_type": "<what optimization was applied>",
    "explanation": "<why this optimization works>",
    "root_cause": "<what caused the performance issue>",
    "expected_improvement": "<expected performance gain>",
    "confidence": 0.85
    }

    Always provide valid JSON. Never include markdown code blocks.
""").strip()
78
 
 
 
 
 
 
 
79
 
80
def build_user_prompt(obs) -> str:
    """Render the per-step user message from an observation.

    The prompt is assembled line by line instead of dedenting an f-string.
    ``textwrap.dedent`` runs after interpolation, so any multi-line value
    (the ``indent=2`` schema dump, a multi-line query) contributes lines at
    column 0, which removes the common margin and leaves the template
    indentation in the prompt.

    Args:
        obs: Observation exposing ``task_description``, ``difficulty`` and a
            dict-like ``current_context`` (``None`` is treated as empty).

    Returns:
        The prompt text with surrounding whitespace stripped.
    """
    ctx = obs.current_context or {}
    schema_json = json.dumps(ctx.get('database_schema', {}), indent=2)
    lines = [
        f"Task: {obs.task_description}",
        f"Difficulty: {obs.difficulty}",
        "",
        "Buggy Query:",
        f"{ctx.get('buggy_query', 'N/A')}",
        "",
        "Error Message:",
        f"{ctx.get('error_message', 'N/A')}",
        "",
        "Database Schema:",
        schema_json,
        "",
        f"Error Type Hint: {ctx.get('error_type_hint', 'unknown')}",
        f"Category: {ctx.get('category', 'unknown')}",
        f"Steps Remaining: {ctx.get('steps_remaining', 20)}",
        "",
        "Analyze the buggy query and provide your fix as a JSON object.",
    ]
    return "\n".join(lines).strip()
101
+
102
+
103
+ # ─────────────────────────────────────────────
104
+ # LLM CALL
105
+ # ─────────────────────────────────────────────
106
+
107
def _extract_json_object(text: str) -> dict:
    """Decode the JSON object contained in an LLM reply.

    Raises:
        ValueError: If no JSON object can be decoded from *text*
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    cleaned = text.strip()
    if "```" in cleaned:
        # Keep the body of the first fenced block and drop an optional "json" tag.
        cleaned = cleaned.split("```")[1]
        if cleaned.startswith("json"):
            cleaned = cleaned[4:]
        cleaned = cleaned.strip()

    try:
        data = json.loads(cleaned)
    except json.JSONDecodeError:
        # Models often surround the object with prose; retry on the outermost
        # {...} span before giving up.
        start, end = cleaned.find("{"), cleaned.rfind("}")
        if start == -1 or end <= start:
            raise
        data = json.loads(cleaned[start:end + 1])

    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object, got {type(data).__name__}")
    return data


def _coerce_confidence(value, default: float = 0.7) -> float:
    """Convert a model-reported confidence to float, or return *default*."""
    try:
        return float(value)
    except (TypeError, ValueError):
        # A null or non-numeric confidence should not discard a valid answer.
        return default


def get_llm_action(client: OpenAI, obs, step: int) -> Action:
    """Call the LLM and parse its response into an Action.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        obs: Current observation, rendered with ``build_user_prompt``.
        step: Current step number (accepted for the caller; not used here).

    Returns:
        An ``OPTIMIZE_QUERY`` action when the reply's ``action_type`` is
        ``"optimize_query"``, otherwise a ``SUBMIT_ANSWER`` action. If the
        request or the parsing fails, a placeholder ``IDENTIFY_ERROR`` action
        is returned and the failure is printed as a ``[DEBUG]`` line.
    """
    user_prompt = build_user_prompt(obs)

    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.3,
            max_tokens=512,
            stream=False,
        )
        text = (completion.choices[0].message.content or "").strip()

        data = _extract_json_object(text)
        confidence = _coerce_confidence(data.get("confidence", 0.7))

        if data.get("action_type", "submit_answer") == "optimize_query":
            return Action(
                action_type=ActionType.OPTIMIZE_QUERY,
                payload={
                    "optimized_query": data.get("optimized_query", "SELECT 1"),
                    "optimization_type": data.get("optimization_type", "Performance fix"),
                    "explanation": data.get("explanation", ""),
                    "root_cause": data.get("root_cause", ""),
                    "expected_improvement": data.get("expected_improvement", ""),
                    "confidence": confidence,
                },
            )

        # Any other action_type is treated as a plain answer submission.
        return Action(
            action_type=ActionType.SUBMIT_ANSWER,
            payload={
                "fixed_query": data.get("fixed_query", "SELECT 1"),
                "explanation": data.get("explanation", ""),
                "error_type": data.get("error_type", "syntax"),
                "error_location": data.get("error_location", "unknown"),
                "confidence": confidence,
            },
        )

    except Exception as exc:
        # Best-effort boundary: the episode continues with a fallback action
        # whether the failure came from the request or from parsing the reply.
        print(f"[DEBUG] LLM call failed: {exc}", flush=True)
        return Action(
            action_type=ActionType.IDENTIFY_ERROR,
            payload={
                "error_location": "unknown",
                "error_type": "syntax",
                "explanation": "LLM call failed, using fallback",
            },
        )
170
+
171
+
172
+ # ─────────────────────────────────────────────
173
+ # MAIN INFERENCE LOOP
174
+ # ─────────────────────────────────────────────
175
+
176
def run_episode(client: OpenAI, difficulty: str, task_id: str) -> dict:
    """Run one full episode and return its summary.

    Exactly one ``[START]`` and one ``[END]`` record are printed, even when
    creating or resetting the environment fails. In that case the episode is
    reported with zero steps and a score of 0.0 instead of raising, so the
    caller can continue with the remaining tasks.

    Args:
        client: OpenAI-compatible client passed through to ``get_llm_action``.
        difficulty: Difficulty level forwarded to ``env.reset``.
        task_id: Task identifier forwarded to ``env.reset`` and used in logs.

    Returns:
        Dict with ``task_id``, ``difficulty``, ``score``, ``steps`` and ``success``.
    """
    rewards = []
    steps = 0
    success = False
    score = 0.0

    # Print [START] before touching the environment so the [END] emitted by
    # the finally block always has a matching [START].
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

    try:
        env = SQLDebuggerEnvironment()
        obs = env.reset(difficulty=difficulty, task_id=task_id)

        for step in range(1, MAX_STEPS + 1):
            if env.state().done:
                break

            # Get action from LLM
            action = get_llm_action(client, obs, step)
            action_str = f"{action.action_type.value}"
            error_str = None

            try:
                resp = env.step(action)
                reward = resp.reward.score
                done = resp.done
                obs = resp.observation
            except Exception as e:
                # A rejected step costs a small penalty; the episode goes on
                # with the previous observation.
                reward = -0.1
                done = False
                error_str = str(e)[:100]

            rewards.append(reward)
            steps = step

            log_step(
                step=step,
                action=action_str,
                reward=reward,
                done=done,
                error=error_str,
            )

            if done:
                break

        # NOTE(review): the sum is divided by MAX_STEPS, not by the steps
        # actually taken, so an episode solved in one step scores at most
        # reward/MAX_STEPS. Confirm this matches the environment's reward scale.
        total_reward = sum(rewards)
        score = min(max(total_reward / MAX_STEPS, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD

    except Exception as e:
        # score and success keep their failure defaults.
        print(f"[DEBUG] Episode error: {e}", flush=True)

    finally:
        log_end(
            success=success,
            steps=steps,
            score=score,
            rewards=rewards,
        )

    return {
        "task_id": task_id,
        "difficulty": difficulty,
        "score": score,
        "steps": steps,
        "success": success,
    }
245
+
246
 
 
247
def main():
    """Main entry point — runs inference on all 3 difficulty levels."""
    print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", flush=True)
    print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", flush=True)

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # One (difficulty, task_id) pair per difficulty level, run in this order.
    tasks = (
        ("easy", "easy_001"),
        ("medium", "medium_001"),
        ("hard", "hard_001"),
    )
    results = [run_episode(client, level, task) for level, task in tasks]

    # Final summary
    total = sum(entry["score"] for entry in results)
    avg_score = total / len(results)
    print(f"\n[DEBUG] Average Score: {avg_score:.3f}", flush=True)
    for r in results:
        print(f"[DEBUG] {r['difficulty']:8} | {r['task_id']:12} | score={r['score']:.3f} | steps={r['steps']}", flush=True)


if __name__ == "__main__":
    main()