Souravdanyal commited on
Commit
b464fa4
Β·
1 Parent(s): 75234fe

Improve LLM prompting and feedback handling

Browse files
Files changed (1) hide show
  1. inference.py +97 -32
inference.py CHANGED
@@ -29,10 +29,6 @@ MAX_STEPS = 5
29
  client = OpenAI(api_key=HF_TOKEN or "dummy", base_url=API_BASE_URL)
30
 
31
  # ─── Logging β€” STRICT FORMAT REQUIRED BY EVALUATOR ───────────────────────────
32
- # [START] task=<task_id> env=<benchmark> model=<model_name>
33
- # [STEP] step=<n> action=<str> reward=<0.00> done=<true|false> error=<msg|null>
34
- # [END] success=<true|false> steps=<n> rewards=<r1,r2,...,rn>
35
-
36
  def log_start(task_id: str, env: str, model: str) -> None:
37
  print(f"[START] task={task_id} env={env} model={model}", flush=True)
38
 
@@ -60,38 +56,64 @@ def env_step(env_url: str, fixed_code: str, explanation: str = None) -> dict:
60
  return resp.json()
61
 
62
  # ─── LLM Agent ───────────────────────────────────────────────────────────────
63
- SYSTEM_PROMPT = """You are an expert Python debugging agent.
64
- You will be given buggy Python code and must fix it.
65
 
66
- For easy tasks: fix the single bug.
67
- For medium tasks: fix both bugs.
68
- For hard tasks: fix the algorithmic bug AND explain your reasoning in the 'explanation' field.
 
 
69
 
70
- You MUST respond ONLY with valid JSON in this exact format:
71
  {
72
- "fixed_code": "<complete fixed Python function as a string>",
73
- "explanation": "<required for hard tasks; describe what was wrong and why your fix is correct>"
74
  }
75
 
76
- Rules:
77
- - Return the COMPLETE function, not just the changed line.
78
- - The fixed_code must be valid Python that can be exec'd.
79
- - For hard tasks, explanation must discuss the algorithm, root cause, and fix.
80
- - Do NOT include markdown fences or any text outside the JSON object.
 
 
81
  """
82
 
83
  def call_llm(buggy_code: str, instructions: str, difficulty: str,
84
- feedback: str = None, attempt: int = 1) -> dict:
 
 
85
  user_content = f"""Task difficulty: {difficulty}
86
  Instructions: {instructions}
87
 
88
- Buggy code:
89
  ```python
90
  {buggy_code}
91
  ```
92
  """
93
  if feedback and attempt > 1:
94
- user_content += f"\nPrevious attempt feedback:\n{feedback}\n\nPlease fix the remaining issues."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  messages = [
97
  {"role": "system", "content": SYSTEM_PROMPT},
@@ -100,15 +122,39 @@ Buggy code:
100
 
101
  try:
102
  response = client.chat.completions.create(
103
- model=MODEL_NAME, messages=messages, max_tokens=1000, temperature=0.1,
 
 
 
104
  )
105
  content = response.choices[0].message.content.strip()
 
 
106
  if content.startswith("```"):
107
  lines = content.split("\n")
108
- content = "\n".join(lines[1:-1]) if lines[-1] == "```" else "\n".join(lines[1:])
 
 
 
 
109
  parsed = json.loads(content)
110
- return {"fixed_code": parsed.get("fixed_code", ""), "explanation": parsed.get("explanation", None)}
 
 
 
111
  except json.JSONDecodeError:
 
 
 
 
 
 
 
 
 
 
 
 
112
  return {"fixed_code": buggy_code, "explanation": None}
113
  except Exception as e:
114
  print(f"# LLM call failed: {e}", file=sys.stderr)
@@ -125,27 +171,38 @@ def run_episode(env_url: str, difficulty: str) -> tuple:
125
  log_start(task_id=task_id, env=BENCHMARK, model=MODEL_NAME)
126
 
127
  last_feedback = None
 
128
  rewards: List[float] = []
129
  steps_taken = 0
130
  success = False
131
 
132
  for attempt in range(1, MAX_STEPS + 1):
133
  steps_taken = attempt
 
134
  agent_action = call_llm(
135
- buggy_code=buggy_code, instructions=instructions,
136
- difficulty=difficulty, feedback=last_feedback, attempt=attempt,
 
 
 
 
137
  )
 
138
  fixed_code = agent_action["fixed_code"]
 
139
 
140
  if not fixed_code or not fixed_code.strip():
141
- log_step(step=attempt, action="empty_submission", reward=0.0, done=False, error="empty_code")
 
142
  rewards.append(0.0)
143
  continue
144
 
145
  try:
146
- result = env_step(env_url, fixed_code=fixed_code, explanation=agent_action.get("explanation"))
 
147
  except Exception as e:
148
- log_step(step=attempt, action="step_failed", reward=0.0, done=False, error=str(e)[:60])
 
149
  rewards.append(0.0)
150
  continue
151
 
@@ -154,11 +211,13 @@ def run_episode(env_url: str, difficulty: str) -> tuple:
154
  obs_r = result.get("observation", {})
155
  last_feedback = obs_r.get("feedback", "")
156
 
157
- log_step(step=attempt, action=f"fix_{difficulty}_attempt{attempt}", reward=reward, done=done, error=None)
 
158
  rewards.append(reward)
159
 
160
  if reward >= 1.0:
161
  success = True
 
162
  if done:
163
  break
164
 
@@ -168,7 +227,8 @@ def run_episode(env_url: str, difficulty: str) -> tuple:
168
  def main():
169
  parser = argparse.ArgumentParser(description="Code Debug Environment Baseline Agent")
170
  parser.add_argument("--url", default=ENV_URL, help="Environment base URL")
171
- parser.add_argument("--difficulty", default=None, choices=["easy", "medium", "hard", "all"])
 
172
  args = parser.parse_args()
173
  env_url = args.url.rstrip("/")
174
 
@@ -180,10 +240,12 @@ def main():
180
  print(f"# Health check failed: {e}", file=sys.stderr)
181
  sys.exit(1)
182
 
183
- difficulties = ["easy", "medium", "hard"] if (args.difficulty in ("all", None)) else [args.difficulty]
 
184
 
185
  all_rewards = []
186
  all_successes = []
 
187
  for difficulty in difficulties:
188
  success, steps, rewards = run_episode(env_url, difficulty)
189
  all_rewards.extend(rewards)
@@ -191,7 +253,10 @@ def main():
191
  time.sleep(0.5)
192
 
193
  avg = round(sum(all_rewards) / len(all_rewards), 3) if all_rewards else 0.0
194
- print(f"# SUMMARY: {sum(all_successes)}/{len(difficulties)} tasks solved | avg_reward={avg}", flush=True)
 
 
 
195
 
196
  if __name__ == "__main__":
197
  main()
 
29
  client = OpenAI(api_key=HF_TOKEN or "dummy", base_url=API_BASE_URL)
30
 
31
  # ─── Logging β€” STRICT FORMAT REQUIRED BY EVALUATOR ───────────────────────────
 
 
 
 
32
  def log_start(task_id: str, env: str, model: str) -> None:
33
  print(f"[START] task={task_id} env={env} model={model}", flush=True)
34
 
 
56
  return resp.json()
57
 
58
  # ─── LLM Agent ───────────────────────────────────────────────────────────────
59
+ SYSTEM_PROMPT = """You are an expert Python debugging agent. Your job is to find and fix bugs in Python functions.
 
60
 
61
+ CRITICAL RULES:
62
+ - You MUST respond ONLY with valid JSON β€” no markdown, no explanation outside JSON
63
+ - Return the COMPLETE fixed function, not just the changed line
64
+ - The fixed_code must be syntactically valid Python
65
+ - For hard tasks, the explanation field MUST describe: what the bug was, why it caused failures, and how your fix resolves it
66
 
67
+ Response format (strictly):
68
  {
69
+ "fixed_code": "<complete corrected Python function>",
70
+ "explanation": "<for hard tasks: detailed explanation of bug and fix>"
71
  }
72
 
73
+ DEBUGGING STRATEGY:
74
+ 1. Read the instructions carefully β€” they tell you exactly what type of bug exists
75
+ 2. Trace through the logic with the test inputs mentally
76
+ 3. For easy tasks: find the ONE wrong operator, value, or return statement
77
+ 4. For medium tasks: find BOTH bugs β€” usually one logic bug + one edge case
78
+ 5. For hard tasks: find the algorithmic flaw + write a clear explanation
79
+ 6. If your previous attempt failed, READ THE FEEDBACK β€” it shows exactly which inputs failed and what output was expected
80
  """
81
 
82
  def call_llm(buggy_code: str, instructions: str, difficulty: str,
83
+ feedback: str = None, attempt: int = 1,
84
+ previous_code: str = None) -> dict:
85
+
86
  user_content = f"""Task difficulty: {difficulty}
87
  Instructions: {instructions}
88
 
89
+ Buggy code to fix:
90
  ```python
91
  {buggy_code}
92
  ```
93
  """
94
  if feedback and attempt > 1:
95
+ user_content += f"""
96
+ PREVIOUS ATTEMPT FAILED. Here is the feedback showing what went wrong:
97
+ {feedback}
98
+
99
+ Your previous fix was:
100
+ ```python
101
+ {previous_code or 'unknown'}
102
+ ```
103
+
104
+ IMPORTANT: Your previous fix did not work. Carefully analyze the feedback above.
105
+ Look at the Input, Expected, and Got values for each failing test.
106
+ Try a completely different approach to fix the bug.
107
+ """
108
+
109
+ if difficulty == "hard":
110
+ user_content += """
111
+ Remember: For hard tasks you MUST include a detailed explanation field describing:
112
+ - What the algorithmic bug was
113
+ - Why it caused incorrect results
114
+ - How your fix resolves it
115
+ Explanation quality affects 30% of your reward.
116
+ """
117
 
118
  messages = [
119
  {"role": "system", "content": SYSTEM_PROMPT},
 
122
 
123
  try:
124
  response = client.chat.completions.create(
125
+ model=MODEL_NAME,
126
+ messages=messages,
127
+ max_tokens=1500,
128
+ temperature=0.2 if attempt == 1 else 0.5,
129
  )
130
  content = response.choices[0].message.content.strip()
131
+
132
+ # Strip markdown fences
133
  if content.startswith("```"):
134
  lines = content.split("\n")
135
+ content = "\n".join(lines[1:-1]) if lines[-1].strip() == "```" else "\n".join(lines[1:])
136
+ # Strip json prefix
137
+ if content.startswith("json"):
138
+ content = content[4:].strip()
139
+
140
  parsed = json.loads(content)
141
+ return {
142
+ "fixed_code": parsed.get("fixed_code", ""),
143
+ "explanation": parsed.get("explanation", None),
144
+ }
145
  except json.JSONDecodeError:
146
+ # Try to extract code from malformed response
147
+ if "def " in content:
148
+ lines = content.split("\n")
149
+ code_lines = []
150
+ in_code = False
151
+ for line in lines:
152
+ if line.strip().startswith("def "):
153
+ in_code = True
154
+ if in_code:
155
+ code_lines.append(line)
156
+ if code_lines:
157
+ return {"fixed_code": "\n".join(code_lines), "explanation": None}
158
  return {"fixed_code": buggy_code, "explanation": None}
159
  except Exception as e:
160
  print(f"# LLM call failed: {e}", file=sys.stderr)
 
171
  log_start(task_id=task_id, env=BENCHMARK, model=MODEL_NAME)
172
 
173
  last_feedback = None
174
+ last_fixed_code = None
175
  rewards: List[float] = []
176
  steps_taken = 0
177
  success = False
178
 
179
  for attempt in range(1, MAX_STEPS + 1):
180
  steps_taken = attempt
181
+
182
  agent_action = call_llm(
183
+ buggy_code=buggy_code,
184
+ instructions=instructions,
185
+ difficulty=difficulty,
186
+ feedback=last_feedback,
187
+ attempt=attempt,
188
+ previous_code=last_fixed_code,
189
  )
190
+
191
  fixed_code = agent_action["fixed_code"]
192
+ last_fixed_code = fixed_code
193
 
194
  if not fixed_code or not fixed_code.strip():
195
+ log_step(step=attempt, action="empty_submission",
196
+ reward=0.0, done=False, error="empty_code")
197
  rewards.append(0.0)
198
  continue
199
 
200
  try:
201
+ result = env_step(env_url, fixed_code=fixed_code,
202
+ explanation=agent_action.get("explanation"))
203
  except Exception as e:
204
+ log_step(step=attempt, action="step_failed",
205
+ reward=0.0, done=False, error=str(e)[:60])
206
  rewards.append(0.0)
207
  continue
208
 
 
211
  obs_r = result.get("observation", {})
212
  last_feedback = obs_r.get("feedback", "")
213
 
214
+ log_step(step=attempt, action=f"fix_{difficulty}_attempt{attempt}",
215
+ reward=reward, done=done, error=None)
216
  rewards.append(reward)
217
 
218
  if reward >= 1.0:
219
  success = True
220
+
221
  if done:
222
  break
223
 
 
227
  def main():
228
  parser = argparse.ArgumentParser(description="Code Debug Environment Baseline Agent")
229
  parser.add_argument("--url", default=ENV_URL, help="Environment base URL")
230
+ parser.add_argument("--difficulty", default=None,
231
+ choices=["easy", "medium", "hard", "all"])
232
  args = parser.parse_args()
233
  env_url = args.url.rstrip("/")
234
 
 
240
  print(f"# Health check failed: {e}", file=sys.stderr)
241
  sys.exit(1)
242
 
243
+ difficulties = ["easy", "medium", "hard"] if (
244
+ args.difficulty in ("all", None)) else [args.difficulty]
245
 
246
  all_rewards = []
247
  all_successes = []
248
+
249
  for difficulty in difficulties:
250
  success, steps, rewards = run_episode(env_url, difficulty)
251
  all_rewards.extend(rewards)
 
253
  time.sleep(0.5)
254
 
255
  avg = round(sum(all_rewards) / len(all_rewards), 3) if all_rewards else 0.0
256
+ print(
257
+ f"# SUMMARY: {sum(all_successes)}/{len(difficulties)} tasks solved | avg_reward={avg}",
258
+ flush=True
259
+ )
260
 
261
  if __name__ == "__main__":
262
  main()