SimranShaikh committed on
Commit
7d77fa5
·
verified ·
1 Parent(s): 319e242
Files changed (1) hide show
  1. inference.py +27 -44
inference.py CHANGED
@@ -1,7 +1,6 @@
1
  """
2
  inference.py — Baseline inference script for CodeReview-Env.
3
-
4
- Mandatory [START] / [STEP] / [END] log format for OpenEnv evaluators.
5
 
6
  Environment variables:
7
  API_BASE_URL LLM API base URL
@@ -31,38 +30,26 @@ SUCCESS_SCORE_THRESHOLD = 0.6
31
  TASKS = ["easy_syntax", "medium_logic", "hard_security"]
32
 
33
 
34
- # ── Mandatory log format ──────────────────────────────────────
35
 
36
  def log_start(task: str, env: str, model: str) -> None:
37
- print(json.dumps({"type": "START", "task": task, "env": env, "model": model}), flush=True)
38
 
39
 
40
  def log_step(step: int, action: Any, reward: float, done: bool, error: Optional[str] = None) -> None:
41
- print(json.dumps({
42
- "type": "STEP",
43
- "step": step,
44
- "action": str(action)[:300],
45
- "reward": reward,
46
- "done": done,
47
- "error": error,
48
- }), flush=True)
49
 
50
 
51
  def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
52
- print(json.dumps({
53
- "type": "END",
54
- "success": success,
55
- "steps": steps,
56
- "score": score,
57
- "rewards": rewards,
58
- }), flush=True)
59
 
60
 
61
  # ── HTTP client for OpenEnv server ───────────────────────────
62
 
63
  class EnvClient:
64
- """Thin HTTP client for the OpenEnv-compliant Space API."""
65
-
66
  def __init__(self, base_url: str) -> None:
67
  self.base_url = base_url
68
  self._http = httpx.Client(timeout=60.0)
@@ -72,7 +59,7 @@ class EnvClient:
72
  try:
73
  r = self._http.get(f"{self.base_url}/health")
74
  if r.status_code == 200:
75
- print(f"[DEBUG] Server ready after {i} retries", flush=True)
76
  return True
77
  except Exception as e:
78
  print(f"[DEBUG] Waiting for server ({i+1}/{retries}): {e}", flush=True)
@@ -80,7 +67,6 @@ class EnvClient:
80
  return False
81
 
82
  def reset(self) -> Dict:
83
- """POST /reset — returns {observation, reward, done}"""
84
  try:
85
  r = self._http.post(f"{self.base_url}/reset")
86
  r.raise_for_status()
@@ -90,11 +76,9 @@ class EnvClient:
90
  return {"observation": {}, "reward": 0.0, "done": False}
91
 
92
  def step(self, action: Dict) -> Dict:
93
- """POST /step with {action: {...}} wrapper — OpenEnv format."""
94
  try:
95
- # OpenEnv create_app requires action wrapped: {"action": {...}}
96
- payload = {"action": action}
97
- r = self._http.post(f"{self.base_url}/step", json=payload)
98
  r.raise_for_status()
99
  return r.json()
100
  except Exception as e:
@@ -112,9 +96,8 @@ class EnvClient:
112
 
113
  SYSTEM_PROMPT = """\
114
  You are an expert software engineer specialising in code review, debugging, \
115
- and security auditing. You will be shown a code snippet with a task description.
116
-
117
- Return ONLY a JSON object in this exact format (no prose, no markdown fences):
118
 
119
  {
120
  "identified_issues": [
@@ -126,7 +109,7 @@ Return ONLY a JSON object in this exact format (no prose, no markdown fences):
126
  }
127
  ],
128
  "suggested_fix": "<complete corrected code as string, or null>",
129
- "explanation": "<brief summary of findings>",
130
  "submit": true
131
  }
132
  """
@@ -173,7 +156,6 @@ def call_llm(llm_client: OpenAI, prompt: str) -> str:
173
 
174
  def parse_llm_output(raw: str) -> Dict:
175
  raw = raw.strip()
176
- # Strip markdown fences if present
177
  if raw.startswith("```"):
178
  parts = raw.split("```")
179
  raw = parts[1] if len(parts) > 1 else raw
@@ -185,7 +167,7 @@ def parse_llm_output(raw: str) -> Dict:
185
  return {
186
  "identified_issues": [],
187
  "suggested_fix": None,
188
- "explanation": raw[:300],
189
  "submit": True,
190
  }
191
 
@@ -204,11 +186,10 @@ def run_task(task_id: str, env: EnvClient, llm: OpenAI) -> float:
204
  try:
205
  result = env.reset()
206
  obs = result.get("observation", {})
207
- max_steps = obs.get("max_steps", 5)
208
 
209
  for step in range(1, max_steps + 1):
210
- done = result.get("done", False)
211
- if done:
212
  break
213
 
214
  prompt = build_prompt(obs, step, prev_feedback)
@@ -226,7 +207,7 @@ def run_task(task_id: str, env: EnvClient, llm: OpenAI) -> float:
226
 
227
  log_step(
228
  step=step,
229
- action=action.get("explanation", "")[:200],
230
  reward=reward,
231
  done=done,
232
  error=None,
@@ -241,6 +222,7 @@ def run_task(task_id: str, env: EnvClient, llm: OpenAI) -> float:
241
 
242
  except Exception as e:
243
  print(f"[DEBUG] run_task error: {e}", flush=True)
 
244
  score = 0.0
245
  success = False
246
 
@@ -253,34 +235,35 @@ def run_task(task_id: str, env: EnvClient, llm: OpenAI) -> float:
253
  # ── Main ──────────────────────────────────────────────────────
254
 
255
  def main() -> None:
256
- print(f"[DEBUG] Starting inference. SPACE_URL={SPACE_URL} MODEL={MODEL_NAME}", flush=True)
257
 
258
  llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
259
  env = EnvClient(SPACE_URL)
260
 
261
  if not env.wait_until_ready():
262
  print("[ERROR] Server not reachable. Exiting.", flush=True)
 
 
 
263
  sys.exit(1)
264
 
265
  task_scores: Dict[str, float] = {}
266
 
267
  for task_id in TASKS:
268
- print(f"\n{'='*50}\nRunning: {task_id}\n{'='*50}", flush=True)
269
  try:
270
  task_scores[task_id] = run_task(task_id, env, llm)
271
  except Exception as e:
272
  print(f"[DEBUG] Task {task_id} crashed: {e}", flush=True)
 
 
273
  task_scores[task_id] = 0.0
274
  time.sleep(1)
275
 
276
  env.close()
277
 
278
- print(f"\n{'='*50}\nFINAL SCORES\n{'='*50}", flush=True)
279
- for tid, s in task_scores.items():
280
- status = "PASS" if s >= SUCCESS_SCORE_THRESHOLD else "FAIL"
281
- print(f" {tid:25s}: {s:.4f} [{status}]", flush=True)
282
  overall = sum(task_scores.values()) / len(task_scores)
283
- print(f"\n Overall: {overall:.4f}", flush=True)
284
 
285
 
286
  if __name__ == "__main__":
 
1
  """
2
  inference.py — Baseline inference script for CodeReview-Env.
3
+ Uses required [START] / [STEP] / [END] plain-text log format.
 
4
 
5
  Environment variables:
6
  API_BASE_URL LLM API base URL
 
30
  TASKS = ["easy_syntax", "medium_logic", "hard_security"]
31
 
32
 
33
+ # ── MANDATORY log format: plain text [START]/[STEP]/[END] ─────
34
 
35
  def log_start(task: str, env: str, model: str) -> None:
36
+ print(f"[START] task={task} env={env} model={model}", flush=True)
37
 
38
 
39
  def log_step(step: int, action: Any, reward: float, done: bool, error: Optional[str] = None) -> None:
40
+ action_str = str(action)[:100].replace("\n", " ")
41
+ error_str = error if error else "null"
42
+ print(f"[STEP] step={step} action={action_str} reward={reward} done={done} error={error_str}", flush=True)
 
 
 
 
 
43
 
44
 
45
  def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
46
+ rewards_str = str([round(r, 4) for r in rewards])
47
+ print(f"[END] success={success} steps={steps} score={score} rewards={rewards_str}", flush=True)
 
 
 
 
 
48
 
49
 
50
  # ── HTTP client for OpenEnv server ───────────────────────────
51
 
52
  class EnvClient:
 
 
53
  def __init__(self, base_url: str) -> None:
54
  self.base_url = base_url
55
  self._http = httpx.Client(timeout=60.0)
 
59
  try:
60
  r = self._http.get(f"{self.base_url}/health")
61
  if r.status_code == 200:
62
+ print(f"[DEBUG] Server ready", flush=True)
63
  return True
64
  except Exception as e:
65
  print(f"[DEBUG] Waiting for server ({i+1}/{retries}): {e}", flush=True)
 
67
  return False
68
 
69
  def reset(self) -> Dict:
 
70
  try:
71
  r = self._http.post(f"{self.base_url}/reset")
72
  r.raise_for_status()
 
76
  return {"observation": {}, "reward": 0.0, "done": False}
77
 
78
  def step(self, action: Dict) -> Dict:
 
79
  try:
80
+ # OpenEnv create_app requires: {"action": {...}}
81
+ r = self._http.post(f"{self.base_url}/step", json={"action": action})
 
82
  r.raise_for_status()
83
  return r.json()
84
  except Exception as e:
 
96
 
97
  SYSTEM_PROMPT = """\
98
  You are an expert software engineer specialising in code review, debugging, \
99
+ and security auditing. Analyse the code and return ONLY a JSON object \
100
+ (no prose, no markdown fences):
 
101
 
102
  {
103
  "identified_issues": [
 
109
  }
110
  ],
111
  "suggested_fix": "<complete corrected code as string, or null>",
112
+ "explanation": "<brief summary of all findings>",
113
  "submit": true
114
  }
115
  """
 
156
 
157
  def parse_llm_output(raw: str) -> Dict:
158
  raw = raw.strip()
 
159
  if raw.startswith("```"):
160
  parts = raw.split("```")
161
  raw = parts[1] if len(parts) > 1 else raw
 
167
  return {
168
  "identified_issues": [],
169
  "suggested_fix": None,
170
+ "explanation": raw[:200],
171
  "submit": True,
172
  }
173
 
 
186
  try:
187
  result = env.reset()
188
  obs = result.get("observation", {})
189
+ max_steps = int(obs.get("max_steps", 5))
190
 
191
  for step in range(1, max_steps + 1):
192
+ if result.get("done", False):
 
193
  break
194
 
195
  prompt = build_prompt(obs, step, prev_feedback)
 
207
 
208
  log_step(
209
  step=step,
210
+ action=action.get("explanation", "")[:100],
211
  reward=reward,
212
  done=done,
213
  error=None,
 
222
 
223
  except Exception as e:
224
  print(f"[DEBUG] run_task error: {e}", flush=True)
225
+ log_step(step=steps_taken + 1, action="error", reward=0.0, done=True, error=str(e))
226
  score = 0.0
227
  success = False
228
 
 
235
  # ── Main ──────────────────────────────────────────────────────
236
 
237
  def main() -> None:
238
+ print(f"[DEBUG] Starting. SPACE_URL={SPACE_URL} MODEL={MODEL_NAME}", flush=True)
239
 
240
  llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
241
  env = EnvClient(SPACE_URL)
242
 
243
  if not env.wait_until_ready():
244
  print("[ERROR] Server not reachable. Exiting.", flush=True)
245
+ for task_id in TASKS:
246
+ log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
247
+ log_end(success=False, steps=0, score=0.0, rewards=[])
248
  sys.exit(1)
249
 
250
  task_scores: Dict[str, float] = {}
251
 
252
  for task_id in TASKS:
253
+ print(f"\n[DEBUG] ===== Running task: {task_id} =====", flush=True)
254
  try:
255
  task_scores[task_id] = run_task(task_id, env, llm)
256
  except Exception as e:
257
  print(f"[DEBUG] Task {task_id} crashed: {e}", flush=True)
258
+ log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
259
+ log_end(success=False, steps=0, score=0.0, rewards=[])
260
  task_scores[task_id] = 0.0
261
  time.sleep(1)
262
 
263
  env.close()
264
 
 
 
 
 
265
  overall = sum(task_scores.values()) / len(task_scores)
266
+ print(f"\n[DEBUG] Overall average: {overall:.4f}", flush=True)
267
 
268
 
269
  if __name__ == "__main__":