junaid0600 committed on
Commit
5447299
·
1 Parent(s): 04e5467

Update inference.py with [START]/[STEP]/[END] format and dotenv loading

Browse files
Files changed (1) hide show
  1. inference.py +269 -10
inference.py CHANGED
@@ -1,16 +1,275 @@
1
- # inference.py — required by OpenEnv validator
2
- from baseline import run_baseline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- # Entry point for openenv validator
5
- app = None # placeholder for server entry point
6
 
7
  def main():
8
- print("Running inference / baseline agent...")
9
- response = run_baseline()
10
- print(f"\nInference Results:")
11
- for r in response.results:
12
- print(f" {r.difficulty.value:8} | {r.task_id:12} | score={r.score} | steps={r.steps}")
13
- print(f"\nAverage Score: {response.average_score}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  if __name__ == "__main__":
16
  main()
 
1
+ """
2
+ inference.py — SQL Query Debugger OpenEnv
3
+ Follows the mandatory [START]/[STEP]/[END] stdout format.
4
+ Uses OpenAI client with API_BASE_URL, MODEL_NAME, HF_TOKEN.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import textwrap
10
+ from typing import List, Optional
11
+
12
+ from openai import OpenAI
13
+ from dotenv import load_dotenv
14
+ load_dotenv()
15
+
16
+ from env.environment import SQLDebuggerEnvironment
17
+ from env.models import Action, ActionType, DifficultyLevel
18
+
19
+ # ─────────────────────────────────────────────
20
+ # ENVIRONMENT VARIABLES
21
+ # ─────────────────────────────────────────────
22
# Credential for the OpenAI-compatible client: the first non-empty value among
# HF_TOKEN, API_KEY and OPENAI_API_KEY, or a placeholder when none is set.
API_KEY = next(
    (
        candidate
        for candidate in map(os.getenv, ("HF_TOKEN", "API_KEY", "OPENAI_API_KEY"))
        if candidate
    ),
    "dummy-key",
)
# Endpoint and model are overridable from the environment.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# Value reported in the env= field of every [START] record.
BENCHMARK = "sql-query-debugger"
# Upper bound on steps per episode; also the divisor used to normalise the score.
MAX_STEPS = 10
# Minimum normalised score for an episode to be reported as a success.
SUCCESS_SCORE_THRESHOLD = 0.5
28
+
29
+ # ─────────────────────────────────────────────
30
+ # LOGGING FUNCTIONS — exact format required
31
+ # ─────────────────────────────────────────────
32
+
33
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens an episode.

    Args:
        task: Task identifier written to the ``task=`` field.
        env: Benchmark name written to the ``env=`` field.
        model: Model name written to the ``model=`` field.
    """
    fields = f"task={task} env={env} model={model}"
    # Flush immediately so the record is visible even if the process dies later.
    print("[START] " + fields, flush=True)
35
+
36
+
37
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record on a single stdout line.

    Args:
        step: 1-based step number.
        action: Name of the action that was taken.
        reward: Reward for the step, printed with two decimals.
        done: Whether the episode finished; printed as ``true``/``false``.
        error: Error text for the step, or ``None``/empty for ``null``.
    """
    # Exception text is frequently multi-line; collapse every whitespace run to
    # a single space so the record always stays on exactly one line.
    error_val = " ".join(str(error or "").split()) or "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )
41
+
42
+
43
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` record that closes an episode.

    Args:
        success: Outcome flag, printed lower-cased.
        steps: Number of steps that were executed.
        score: Final score, printed with three decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals
            each (an empty list yields an empty ``rewards=`` field).
    """
    rewards_field = ",".join([format(value, ".2f") for value in rewards])
    success_field = str(success).lower()
    record = f"[END] success={success_field} steps={steps} score={score:.3f} rewards={rewards_field}"
    print(record, flush=True)
46
+
47
+
48
+ # ─────────────────────────────────────────────
49
+ # SYSTEM PROMPT
50
+ # ─────────────────────────────────────────────
51
+
52
# System message sent with every model call. It defines the two JSON reply
# shapes that get_llm_action knows how to parse ("submit_answer" and
# "optimize_query"). The dash in the second paragraph is a real em dash; a
# mis-encoded byte sequence there would be sent verbatim to the model.
SYSTEM_PROMPT = textwrap.dedent("""
    You are an expert SQL debugger. You will be given a buggy SQL query and must fix it.

    You must respond with a JSON object only — no explanation outside the JSON.

    For syntax/logic errors, respond with:
    {
      "action_type": "submit_answer",
      "fixed_query": "<your fixed SQL query here>",
      "explanation": "<brief explanation of what was wrong>",
      "error_type": "<syntax|logic|performance>",
      "error_location": "<where in the query the error is>",
      "confidence": 0.9
    }

    For performance issues, respond with:
    {
      "action_type": "optimize_query",
      "optimized_query": "<your optimized SQL query here>",
      "optimization_type": "<what optimization was applied>",
      "explanation": "<why this optimization works>",
      "root_cause": "<what caused the performance issue>",
      "expected_improvement": "<expected performance gain>",
      "confidence": 0.85
    }

    Always provide valid JSON. Never include markdown code blocks.
""").strip()
80
+
81
+
82
def build_user_prompt(obs) -> str:
    """Render the user-turn prompt describing the current debugging task.

    Args:
        obs: Observation exposing ``task_description``, ``difficulty`` and a
            ``current_context`` mapping that is read with ``.get``.

    Returns:
        The prompt text with no leading or trailing whitespace. Missing
        context keys fall back to ``N/A``, ``unknown``, ``{}`` or ``20``.
    """
    ctx = obs.current_context
    schema_json = json.dumps(ctx.get("database_schema", {}), indent=2)

    # Build the prompt from explicit lines instead of dedenting an f-string:
    # dedent runs after interpolation, so a multi-line query or the indented
    # schema JSON would change the common margin and leave stray indentation.
    lines = [
        f"Task: {obs.task_description}",
        f"Difficulty: {obs.difficulty}",
        "",
        "Buggy Query:",
        f"{ctx.get('buggy_query', 'N/A')}",
        "",
        "Error Message:",
        f"{ctx.get('error_message', 'N/A')}",
        "",
        "Database Schema:",
        schema_json,
        "",
        f"Error Type Hint: {ctx.get('error_type_hint', 'unknown')}",
        f"Category: {ctx.get('category', 'unknown')}",
        f"Steps Remaining: {ctx.get('steps_remaining', 20)}",
        "",
        "Analyze the buggy query and provide your fix as a JSON object.",
    ]
    return "\n".join(lines).strip()
103
+
104
+
105
+ # ─────────────────────────────────────────────
106
+ # LLM CALL
107
+ # ─────────────────────────────────────────────
108
+
109
def _extract_json_object(text: str) -> dict:
    """Parse the outermost ``{...}`` span of *text* as a JSON object."""
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end < start:
        raise ValueError("model response contains no JSON object")
    return json.loads(text[start:end + 1])


def _as_confidence(value, default: float = 0.7) -> float:
    """Coerce *value* to ``float``, returning *default* when it is not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def get_llm_action(client: OpenAI, obs, step: int) -> Action:
    """Ask the model for a fix and convert its JSON reply into an ``Action``.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        obs: Current observation; rendered with ``build_user_prompt``.
        step: Current step number. Kept for interface compatibility; unused.

    Returns:
        An ``OPTIMIZE_QUERY`` action when the reply's ``action_type`` is
        ``"optimize_query"``, otherwise a ``SUBMIT_ANSWER`` action. If the
        request, the parsing, or the action construction raises, a fallback
        ``IDENTIFY_ERROR`` action is returned instead of propagating.
    """
    user_prompt = build_user_prompt(obs)

    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.3,
            max_tokens=512,
            stream=False,
        )
        text = (completion.choices[0].message.content or "").strip()

        # Take the outermost {...} span rather than splitting on code fences:
        # this copes with fenced replies and with prose around the JSON.
        data = _extract_json_object(text)
        action_type = data.get("action_type", "submit_answer")
        # A missing or non-numeric confidence must not discard a usable answer.
        confidence = _as_confidence(data.get("confidence", 0.7))

        if action_type == "optimize_query":
            return Action(
                action_type=ActionType.OPTIMIZE_QUERY,
                payload={
                    "optimized_query": data.get("optimized_query", "SELECT 1"),
                    "optimization_type": data.get("optimization_type", "Performance fix"),
                    "explanation": data.get("explanation", ""),
                    "root_cause": data.get("root_cause", ""),
                    "expected_improvement": data.get("expected_improvement", ""),
                    "confidence": confidence,
                },
            )

        return Action(
            action_type=ActionType.SUBMIT_ANSWER,
            payload={
                "fixed_query": data.get("fixed_query", "SELECT 1"),
                "explanation": data.get("explanation", ""),
                "error_type": data.get("error_type", "syntax"),
                "error_location": data.get("error_location", "unknown"),
                "confidence": confidence,
            },
        )

    except Exception as exc:
        # Deliberately broad: transport, parsing and validation failures all
        # degrade to the same fallback so the episode can keep going.
        print(f"[DEBUG] LLM call failed: {exc}", flush=True)
        return Action(
            action_type=ActionType.IDENTIFY_ERROR,
            payload={
                "error_location": "unknown",
                "error_type": "syntax",
                "explanation": "LLM call failed, using fallback",
            },
        )
172
+
173
+
174
+ # ─────────────────────────────────────────────
175
+ # MAIN INFERENCE LOOP
176
+ # ─────────────────────────────────────────────
177
+
178
def run_episode(client: OpenAI, difficulty: str, task_id: str) -> dict:
    """Run one episode and report it in the [START]/[STEP]/[END] format.

    Args:
        client: OpenAI-compatible client passed to ``get_llm_action``.
        difficulty: Difficulty level forwarded to ``env.reset``.
        task_id: Task identifier forwarded to ``env.reset`` and logged.

    Returns:
        A dict with ``task_id``, ``difficulty``, ``score``, ``steps`` and
        ``success``. ``score`` is the summed reward divided by ``MAX_STEPS``
        and clamped to ``[0, 1]``; it stays ``0.0`` if the episode aborts.
    """
    rewards: List[float] = []
    steps = 0
    success = False
    score = 0.0

    # Emit [START] before touching the environment so that a failure while
    # creating or resetting it still yields a matched [START]/[END] pair
    # instead of aborting the whole run.
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

    try:
        env = SQLDebuggerEnvironment()
        obs = env.reset(difficulty=difficulty, task_id=task_id)

        for step in range(1, MAX_STEPS + 1):
            if env.state().done:
                break

            # Get action from LLM
            action = get_llm_action(client, obs, step)
            action_str = f"{action.action_type.value}"
            error_str = None

            try:
                resp = env.step(action)
                reward = resp.reward.score
                done = resp.done
                obs = resp.observation
            except Exception as e:
                # A failed step is penalised and logged; the episode goes on
                # with the previous observation.
                reward = -0.1
                done = False
                error_str = str(e)[:100]

            rewards.append(reward)
            steps = step

            log_step(
                step=step,
                action=action_str,
                reward=reward,
                done=done,
                error=error_str,
            )

            if done:
                break

        # Calculate score
        total_reward = sum(rewards)
        score = min(max(total_reward / MAX_STEPS, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD

    except Exception as e:
        print(f"[DEBUG] Episode error: {e}", flush=True)

    finally:
        # Always close the episode record, whatever happened above.
        log_end(
            success=success,
            steps=steps,
            score=score,
            rewards=rewards,
        )

    return {
        "task_id": task_id,
        "difficulty": difficulty,
        "score": score,
        "steps": steps,
        "success": success,
    }
247
 
 
 
248
 
def main():
    """Run one episode per difficulty level and print a score summary."""
    print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", flush=True)
    print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", flush=True)

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # One (difficulty, task_id) pair per level, run in this order.
    tasks = (
        ("easy", "easy_001"),
        ("medium", "medium_001"),
        ("hard", "hard_001"),
    )
    results = [run_episode(client, level, task) for level, task in tasks]

    # Final summary
    avg_score = sum(entry["score"] for entry in results) / len(results)
    print(f"\n[DEBUG] Average Score: {avg_score:.3f}", flush=True)
    for entry in results:
        print(f"[DEBUG] {entry['difficulty']:8} | {entry['task_id']:12} | score={entry['score']:.3f} | steps={entry['steps']}", flush=True)


if __name__ == "__main__":
    main()