Souravdanyal committed on
Commit
40ac3c8
·
verified ·
1 Parent(s): 66d8c67

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +36 -23
inference.py CHANGED
@@ -28,7 +28,8 @@ except ImportError:
28
 
29
  # ── Config ────────────────────────────────────────────────────────────────────
30
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
31
- MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.1-8b-instant")
 
32
  HF_TOKEN = os.getenv("HF_TOKEN")
33
  HF_TOKEN_SOURCE = "HF_TOKEN"
34
  if not HF_TOKEN:
@@ -37,16 +38,18 @@ if not HF_TOKEN:
37
  if not HF_TOKEN:
38
  HF_TOKEN = os.getenv("hf_token")
39
  HF_TOKEN_SOURCE = "hf_token"
40
- # Optional when using from_docker_image():
41
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
42
- ENV_URL = os.getenv("ENV_URL")
43
- BENCHMARK = "code-debug-env"
44
- MAX_STEPS = 5
45
  SUCCESS_SCORE_THRESHOLD = 0.5
46
 
47
  client = OpenAI(api_key=HF_TOKEN or "dummy", base_url=API_BASE_URL)
48
 
 
49
  # ── Logging — STRICT PLAINTEXT FORMAT ────────────────────────────────────────
 
50
  def _format_bool(value: bool) -> str:
51
  return "true" if value else "false"
52
 
@@ -72,6 +75,7 @@ def log_start(task_id: str, env: str, model: str) -> None:
72
  flush=True,
73
  )
74
 
 
75
  def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
76
  print(
77
  f"[STEP] step={step} action={_normalize_token(action)} reward={round(reward, 2):.2f} "
@@ -79,6 +83,7 @@ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[
79
  flush=True,
80
  )
81
 
 
82
  def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
83
  print(
84
  f"[END] success={_format_bool(success)} steps={steps} score={round(score, 2):.2f} "
@@ -86,12 +91,15 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
86
  flush=True,
87
  )
88
 
 
89
  # ── Env client ────────────────────────────────────────────────────────────────
 
90
  def env_reset(url: str, difficulty: str) -> dict:
91
  r = requests.post(f"{url}/reset", json={"difficulty": difficulty}, timeout=30)
92
  r.raise_for_status()
93
  return r.json()
94
 
 
95
  def env_step(url: str, fixed_code: str, explanation: Optional[str] = None) -> dict:
96
  payload = {"fixed_code": fixed_code}
97
  if explanation:
@@ -100,7 +108,9 @@ def env_step(url: str, fixed_code: str, explanation: Optional[str] = None) -> di
100
  r.raise_for_status()
101
  return r.json()
102
 
 
103
  # ── LLM ──────────────────────────────────────────────────────────────────────
 
104
  SYSTEM_PROMPT = """You are an expert Python debugging agent.
105
 
106
  RESPONSE FORMAT — JSON only, no markdown fences, no extra text:
@@ -120,10 +130,11 @@ COMMON BUG PATTERNS — memorize these:
120
  - Wrong operator: target-n not target+n for complement
121
  - Off-by-one: lst[1] for second element not lst[2]
122
 
123
- IMPORTANT: If feedback shows TimeoutError → you have infinite loop → add visited set.
124
- IMPORTANT: If Expected shows right-rotated list → use lst[-k:] + lst[:-k].
125
  """
126
 
 
127
  def _parse_llm_response(raw: str, buggy_code: str) -> dict:
128
  """Robustly parse LLM response handling control chars and malformed JSON."""
129
  # Remove markdown fences
@@ -174,7 +185,7 @@ def _parse_llm_response(raw: str, buggy_code: str) -> dict:
174
  exp = exp_match.group(1).replace("\\n", "\n") if exp_match else None
175
  return {"fixed_code": code, "explanation": exp}
176
 
177
- # Complete fallback — return buggy code unchanged
178
  return {"fixed_code": buggy_code, "explanation": None}
179
 
180
 
@@ -228,6 +239,7 @@ def call_llm(
228
 
229
 
230
  # ── Episode ───────────────────────────────────────────────────────────────────
 
231
  def run_episode(env_url: str, difficulty: str) -> tuple:
232
  """Run one full episode. Returns (success, steps_taken, rewards)."""
233
  data = env_reset(env_url, difficulty)
@@ -251,19 +263,18 @@ def run_episode(env_url: str, difficulty: str) -> tuple:
251
  code = action.get("fixed_code") or ""
252
  last_code = code
253
 
254
- reward = 0.0
255
- done = False
256
  step_error: Optional[str] = None
 
257
  try:
258
  result = env_step(env_url, code, action.get("explanation"))
259
  reward = result.get("reward", 0.0)
260
- done = result.get("done", False)
261
- obs_r = result.get("observation", {})
262
  if isinstance(obs_r, dict):
263
  last_feedback = obs_r.get("feedback", "")
264
- step_error = obs_r.get("last_action_error")
265
- if step_error is None:
266
- step_error = obs_r.get("error")
267
  except Exception as e:
268
  step_error = str(e)
269
 
@@ -274,10 +285,10 @@ def run_episode(env_url: str, difficulty: str) -> tuple:
274
  success = True
275
  if done:
276
  break
 
277
  finally:
278
- # Compute normalized score for this episode and always emit [END].
279
- score = max(rewards) if rewards else 0.0
280
- score = min(max(score, 0.0), 1.0)
281
  success = success or (score >= SUCCESS_SCORE_THRESHOLD)
282
  log_end(success, steps_taken, score, rewards)
283
 
@@ -285,18 +296,21 @@ def run_episode(env_url: str, difficulty: str) -> tuple:
285
 
286
 
287
  # ── Main ──────────────────────────────────────────────────────────────────────
 
288
  def main():
289
  parser = argparse.ArgumentParser(description="Code Debug Environment Baseline Agent")
290
  parser.add_argument("--url", default=ENV_URL or "http://localhost:7860")
291
- parser.add_argument("--difficulty", default=None, choices=["easy", "medium", "hard", "all"])
 
 
 
292
  args = parser.parse_args()
293
  url = args.url.rstrip("/")
294
 
295
  if not HF_TOKEN:
296
  print(
297
  "# Missing API key. Set HF_TOKEN (or API_KEY / lowercase hf_token).",
298
- file=sys.stderr,
299
- flush=True,
300
  )
301
  sys.exit(1)
302
  print(f"# Using API key from {HF_TOKEN_SOURCE}", file=sys.stderr, flush=True)
@@ -328,5 +342,4 @@ def main():
328
 
329
 
330
  if __name__ == "__main__":
331
- main()
332
-
 
28
 
29
  # ── Config ────────────────────────────────────────────────────────────────────
30
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
31
+ MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.1-8b-instant")
32
+
33
  HF_TOKEN = os.getenv("HF_TOKEN")
34
  HF_TOKEN_SOURCE = "HF_TOKEN"
35
  if not HF_TOKEN:
 
38
  if not HF_TOKEN:
39
  HF_TOKEN = os.getenv("hf_token")
40
  HF_TOKEN_SOURCE = "hf_token"
41
+
42
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
43
+ ENV_URL = os.getenv("ENV_URL")
44
+ BENCHMARK = "code-debug-env"
45
+ MAX_STEPS = 5
46
  SUCCESS_SCORE_THRESHOLD = 0.5
47
 
48
  client = OpenAI(api_key=HF_TOKEN or "dummy", base_url=API_BASE_URL)
49
 
50
+
51
  # ── Logging — STRICT PLAINTEXT FORMAT ────────────────────────────────────────
52
+
53
  def _format_bool(value: bool) -> str:
54
  return "true" if value else "false"
55
 
 
75
  flush=True,
76
  )
77
 
78
+
79
  def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
80
  print(
81
  f"[STEP] step={step} action={_normalize_token(action)} reward={round(reward, 2):.2f} "
 
83
  flush=True,
84
  )
85
 
86
+
87
  def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
88
  print(
89
  f"[END] success={_format_bool(success)} steps={steps} score={round(score, 2):.2f} "
 
91
  flush=True,
92
  )
93
 
94
+
95
  # ── Env client ────────────────────────────────────────────────────────────────
96
+
97
  def env_reset(url: str, difficulty: str) -> dict:
98
  r = requests.post(f"{url}/reset", json={"difficulty": difficulty}, timeout=30)
99
  r.raise_for_status()
100
  return r.json()
101
 
102
+
103
  def env_step(url: str, fixed_code: str, explanation: Optional[str] = None) -> dict:
104
  payload = {"fixed_code": fixed_code}
105
  if explanation:
 
108
  r.raise_for_status()
109
  return r.json()
110
 
111
+
112
  # ── LLM ──────────────────────────────────────────────────────────────────────
113
+
114
  SYSTEM_PROMPT = """You are an expert Python debugging agent.
115
 
116
  RESPONSE FORMAT — JSON only, no markdown fences, no extra text:
 
130
  - Wrong operator: target-n not target+n for complement
131
  - Off-by-one: lst[1] for second element not lst[2]
132
 
133
+ IMPORTANT: If feedback shows TimeoutError, you have infinite loop. Add visited set.
134
+ IMPORTANT: If Expected shows right-rotated list, use lst[-k:] + lst[:-k].
135
  """
136
 
137
+
138
  def _parse_llm_response(raw: str, buggy_code: str) -> dict:
139
  """Robustly parse LLM response handling control chars and malformed JSON."""
140
  # Remove markdown fences
 
185
  exp = exp_match.group(1).replace("\\n", "\n") if exp_match else None
186
  return {"fixed_code": code, "explanation": exp}
187
 
188
+ # Complete fallback
189
  return {"fixed_code": buggy_code, "explanation": None}
190
 
191
 
 
239
 
240
 
241
  # ── Episode ───────────────────────────────────────────────────────────────────
242
+
243
  def run_episode(env_url: str, difficulty: str) -> tuple:
244
  """Run one full episode. Returns (success, steps_taken, rewards)."""
245
  data = env_reset(env_url, difficulty)
 
263
  code = action.get("fixed_code") or ""
264
  last_code = code
265
 
266
+ reward: float = 0.0
267
+ done: bool = False
268
  step_error: Optional[str] = None
269
+
270
  try:
271
  result = env_step(env_url, code, action.get("explanation"))
272
  reward = result.get("reward", 0.0)
273
+ done = result.get("done", False)
274
+ obs_r = result.get("observation", {})
275
  if isinstance(obs_r, dict):
276
  last_feedback = obs_r.get("feedback", "")
277
+ step_error = obs_r.get("last_action_error") or obs_r.get("error")
 
 
278
  except Exception as e:
279
  step_error = str(e)
280
 
 
285
  success = True
286
  if done:
287
  break
288
+
289
  finally:
290
+ score = max(rewards) if rewards else 0.0
291
+ score = min(max(score, 0.0), 1.0)
 
292
  success = success or (score >= SUCCESS_SCORE_THRESHOLD)
293
  log_end(success, steps_taken, score, rewards)
294
 
 
296
 
297
 
298
  # ── Main ──────────────────────────────────────────────────────────────────────
299
+
300
  def main():
301
  parser = argparse.ArgumentParser(description="Code Debug Environment Baseline Agent")
302
  parser.add_argument("--url", default=ENV_URL or "http://localhost:7860")
303
+ parser.add_argument(
304
+ "--difficulty", default=None,
305
+ choices=["easy", "medium", "hard", "all"],
306
+ )
307
  args = parser.parse_args()
308
  url = args.url.rstrip("/")
309
 
310
  if not HF_TOKEN:
311
  print(
312
  "# Missing API key. Set HF_TOKEN (or API_KEY / lowercase hf_token).",
313
+ file=sys.stderr, flush=True,
 
314
  )
315
  sys.exit(1)
316
  print(f"# Using API key from {HF_TOKEN_SOURCE}", file=sys.stderr, flush=True)
 
342
 
343
 
344
  if __name__ == "__main__":
345
+ main()