junaid0600 commited on
Commit
d8cba4f
Β·
1 Parent(s): b02ec3c

Use real LLM calls through API_BASE_URL proxy

Browse files
Files changed (1) hide show
  1. inference.py +138 -15
inference.py CHANGED
@@ -1,8 +1,13 @@
1
  import os
 
 
 
2
  from dotenv import load_dotenv
3
  load_dotenv()
4
 
5
  from openai import OpenAI
 
 
6
 
7
  # ── Required environment variables ──────────────
8
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
@@ -12,28 +17,146 @@ HF_TOKEN = os.getenv("HF_TOKEN")
12
  if HF_TOKEN is None:
13
  raise ValueError("HF_TOKEN environment variable is required")
14
 
15
- # ── Initialize OpenAI client (required by hackathon rules) ──
16
- client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
 
17
 
18
- # ── Import baseline ──────────────────────────────
19
- from baseline import run_baseline
20
 
 
 
21
 
22
- def main():
23
- print(f"[DEBUG] API_BASE_URL={API_BASE_URL}")
24
- print(f"[DEBUG] MODEL_NAME={MODEL_NAME}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- response = run_baseline()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- for r in response.results:
29
- # Ensure score strictly between 0 and 1 exclusive
30
- score = max(0.001, min(0.999, float(r.score)))
31
- print(f"[START] task={r.task_id} env=sql-query-debugger model={MODEL_NAME}")
32
- print(f"[STEP] step=1 action=submit_answer reward={score:.2f} done=true error=null")
33
- print(f"[END] success=true steps=1 rewards={score:.2f}")
34
 
35
- print(f"\n[DEBUG] Average Score: {response.average_score:.3f}")
 
 
 
36
 
 
 
37
 
38
  if __name__ == "__main__":
39
  main()
 
1
  import os
2
+ import json
3
+ import textwrap
4
+ from typing import List, Optional
5
  from dotenv import load_dotenv
6
  load_dotenv()
7
 
8
  from openai import OpenAI
9
+ from env.environment import SQLDebuggerEnvironment
10
+ from env.models import Action, ActionType
11
 
12
  # ── Required environment variables ──────────────
13
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
 
17
  if HF_TOKEN is None:
18
  raise ValueError("HF_TOKEN environment variable is required")
19
 
20
+ client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
21
+ BENCHMARK = "sql-query-debugger"
22
+ MAX_STEPS = 5
23
 
24
+ SYSTEM_PROMPT = """You are an expert SQL debugger. Given a buggy SQL query, respond with ONLY a JSON object.
 
25
 
26
+ For syntax/logic errors:
27
+ {"action_type":"submit_answer","fixed_query":"<fixed SQL>","explanation":"<what was wrong>","error_type":"syntax","confidence":0.9}
28
 
29
+ For performance issues:
30
+ {"action_type":"optimize_query","optimized_query":"<optimized SQL>","optimization_type":"<what was optimized>","explanation":"<why>","root_cause":"<cause>","expected_improvement":"<improvement>","confidence":0.85}
31
+
32
+ Never include markdown. Only valid JSON."""
33
+
34
+ def log_start(task, env, model):
35
+ print(f"[START] task={task} env={env} model={model}", flush=True)
36
+
37
+ def log_step(step, action, reward, done, error=None):
38
+ error_val = error if error else "null"
39
+ print(f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error_val}", flush=True)
40
+
41
+ def log_end(success, steps, rewards):
42
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
43
+ print(f"[END] success={str(success).lower()} steps={steps} rewards={rewards_str}", flush=True)
44
+
45
+ def get_llm_action(obs) -> Action:
46
+ ctx = obs.current_context
47
+ prompt = f"""Task: {obs.task_description}
48
+ Buggy Query: {ctx.get('buggy_query','N/A')}
49
+ Error: {ctx.get('error_message','N/A')}
50
+ Schema: {json.dumps(ctx.get('database_schema',{}))}
51
+ Category: {ctx.get('category','syntax')}
52
+ Fix this SQL query and respond with JSON only."""
53
+
54
+ try:
55
+ completion = client.chat.completions.create(
56
+ model=MODEL_NAME,
57
+ messages=[
58
+ {"role": "system", "content": SYSTEM_PROMPT},
59
+ {"role": "user", "content": prompt}
60
+ ],
61
+ temperature=0.3,
62
+ max_tokens=512,
63
+ )
64
+ text = (completion.choices[0].message.content or "").strip()
65
+ if "```" in text:
66
+ text = text.split("```")[1]
67
+ if text.startswith("json"):
68
+ text = text[4:]
69
+ text = text.strip()
70
+ data = json.loads(text)
71
+
72
+ if data.get("action_type") == "optimize_query":
73
+ return Action(action_type=ActionType.OPTIMIZE_QUERY, payload={
74
+ "optimized_query": data.get("optimized_query", "SELECT 1"),
75
+ "optimization_type": data.get("optimization_type", "fix"),
76
+ "explanation": data.get("explanation", ""),
77
+ "root_cause": data.get("root_cause", ""),
78
+ "expected_improvement": data.get("expected_improvement", ""),
79
+ "confidence": float(data.get("confidence", 0.7)),
80
+ })
81
+ else:
82
+ return Action(action_type=ActionType.SUBMIT_ANSWER, payload={
83
+ "fixed_query": data.get("fixed_query", "SELECT 1"),
84
+ "explanation": data.get("explanation", ""),
85
+ "error_type": data.get("error_type", "syntax"),
86
+ "error_location": data.get("error_location", "unknown"),
87
+ "confidence": float(data.get("confidence", 0.7)),
88
+ })
89
+ except Exception as e:
90
+ print(f"[DEBUG] LLM failed: {e}", flush=True)
91
+ return Action(action_type=ActionType.IDENTIFY_ERROR, payload={
92
+ "error_location": "unknown",
93
+ "error_type": "syntax",
94
+ "explanation": "fallback"
95
+ })
96
 
97
+ def run_episode(difficulty, task_id):
98
+ env = SQLDebuggerEnvironment()
99
+ obs = env.reset(difficulty=difficulty, task_id=task_id)
100
+ rewards = []
101
+ steps = 0
102
+ success = False
103
+
104
+ log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
105
+
106
+ try:
107
+ for step in range(1, MAX_STEPS + 1):
108
+ if env.state().done:
109
+ break
110
+ action = get_llm_action(obs)
111
+ error_str = None
112
+ try:
113
+ resp = env.step(action)
114
+ raw_reward = resp.reward.score
115
+ done = resp.done
116
+ obs = resp.observation
117
+ except Exception as e:
118
+ raw_reward = 0.1
119
+ done = False
120
+ error_str = str(e)[:50]
121
+
122
+ # Normalize reward strictly between 0 and 1
123
+ reward = max(0.01, min(0.99, (raw_reward + 1.0) / 2.0))
124
+ rewards.append(reward)
125
+ steps = step
126
+ log_step(step=step, action=action.action_type.value, reward=reward, done=done, error=error_str)
127
+ if done:
128
+ break
129
+
130
+ score = max(0.01, min(0.99, sum(rewards) / len(rewards))) if rewards else 0.5
131
+ success = score > 0.5
132
+
133
+ except Exception as e:
134
+ print(f"[DEBUG] Episode error: {e}", flush=True)
135
+ score = 0.5
136
+ success = False
137
+ finally:
138
+ safe_rewards = rewards if rewards else [0.5]
139
+ log_end(success=success, steps=steps, rewards=safe_rewards)
140
+
141
+ return {"task_id": task_id, "score": score, "steps": steps}
142
+
143
+ def main():
144
+ print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", flush=True)
145
+ print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", flush=True)
146
 
147
+ tasks = [
148
+ ("easy", "easy_001"),
149
+ ("medium", "medium_001"),
150
+ ("hard", "hard_001"),
151
+ ]
 
152
 
153
+ results = []
154
+ for difficulty, task_id in tasks:
155
+ result = run_episode(difficulty, task_id)
156
+ results.append(result)
157
 
158
+ avg = sum(r["score"] for r in results) / len(results)
159
+ print(f"\n[DEBUG] Average Score: {avg:.3f}", flush=True)
160
 
161
  if __name__ == "__main__":
162
  main()