codemaverick2 commited on
Commit
6535d0b
·
1 Parent(s): d3e536b

Add line numbers to code display, fix clear_flag loop, match submission spec exactly

Browse files
Files changed (1) hide show
  1. inference.py +158 -179
inference.py CHANGED
@@ -1,11 +1,14 @@
1
  """
2
- Inference script for the Code Review Environment.
 
 
 
 
 
3
 
4
- Environment variables (MANDATORY):
5
- API_BASE_URL — LLM API endpoint (default: https://router.huggingface.co/v1)
6
- MODEL_NAME — Model identifier (default: Qwen/Qwen2.5-72B-Instruct)
7
- HF_TOKEN — Your HuggingFace / API key (no default — must be set)
8
- ENV_URL — Environment base URL (default: http://localhost:7860)
9
 
10
  Usage:
11
  export HF_TOKEN=hf_...
@@ -17,13 +20,14 @@ import os
17
  import sys
18
  import json
19
  import time
20
- from typing import Optional
21
 
22
  import httpx
 
23
 
24
- API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
25
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
26
- HF_TOKEN = os.getenv("HF_TOKEN")
27
  ENV_URL: str = os.getenv("ENV_URL", "http://localhost:7860").rstrip("/")
28
  BENCHMARK = "code-review-env"
29
 
@@ -45,9 +49,9 @@ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[
45
  )
46
 
47
 
48
- def log_end(success: bool, steps: int, score: float, rewards: list) -> None:
49
  rewards_str = ",".join(f"{r:.2f}" for r in rewards)
50
- print(f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}", flush=True)
51
 
52
  # Curriculum ordering: easy → medium → medium-hard → hard
53
  # Research (CAMRL, Curriculum RL): start with simpler tasks to build
@@ -115,17 +119,16 @@ When finished, respond with:
115
  ## RULES
116
  - Raw JSON only — no markdown fences, no extra text
117
  - One action per response
118
- - Count lines carefully from line 1 (including blank lines and comments)
119
  - Only flag REAL issues β€” no style preferences, no hypothetical issues
120
  - Be precise: "SQL injection at line 19 via f-string in SELECT query" not just "SQL injection"
121
  - Flag the EXACT line where the problem code is (the f-string line, not the function def)
 
122
  """
123
 
124
 
125
  def chat_completion(messages: list) -> str:
126
- from openai import OpenAI
127
-
128
- client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
129
  try:
130
  response = client.chat.completions.create(
131
  model=MODEL_NAME,
@@ -258,26 +261,30 @@ def _should_submit(obs: dict, step_count: int, max_steps: int) -> bool:
258
  return False
259
 
260
 
 
 
 
261
  def _should_clear_flag(obs: dict, last_reward: float, last_action: dict) -> Optional[dict]:
262
  """
263
  Recovery strategy: if the last flag was a false positive with high penalty,
264
- suggest clearing it to recover partial reward and improve precision.
265
-
266
- Returns a clear_flag action dict if we should recover, else None.
267
  """
268
  if last_reward is None or last_reward >= 0:
269
  return None
270
  if last_action.get("action_type") != "flag_issue":
271
  return None
272
 
273
- # Only clear if it was a clear FP (no near-miss indicator in feedback)
274
- # and we've got too many false positives
 
 
 
275
  progress = obs.get("progress", {})
276
  fp = int(progress.get("false_positives", 0))
277
  tp = int(progress.get("true_positives", 0))
278
 
279
- # If FP > TP and last reward was notably negative, clear the bad flag
280
  if fp > tp and last_reward <= -0.05:
 
281
  return {
282
  "action_type": "clear_flag",
283
  "line_number": last_action.get("line_number"),
@@ -289,7 +296,8 @@ def _should_clear_flag(obs: dict, last_reward: float, last_action: dict) -> Opti
289
 
290
  def run_task(task_id: str, http_client: httpx.Client) -> dict:
291
  log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
292
- all_rewards: list = []
 
293
  step_count = 0
294
  final_score = 0.0
295
 
@@ -297,163 +305,141 @@ def run_task(task_id: str, http_client: httpx.Client) -> dict:
297
  resp = http_client.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=30)
298
  resp.raise_for_status()
299
  obs = resp.json()
300
- except Exception as e:
301
- print(f"[DEBUG] Reset failed: {e}", flush=True)
302
- log_end(success=False, steps=0, score=0.0, rewards=[])
303
- return run_keyword_fallback(ENV_URL, task_id)
304
 
305
- code_display = "\n\n".join(
306
- f"=== {fname} (starting at line 1) ===\n{code}"
307
- for fname, code in obs.get("code_files", {}).items()
308
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
 
310
- # Include function map hint if available
311
- code_metadata = obs.get("code_metadata") or {}
312
- function_ranges = code_metadata.get("function_ranges", [])
313
- fn_map_hint = ""
314
- if function_ranges:
315
- fn_lines = [f" {fr['name']}() in {fr['file']} (lines {fr['start']}-{fr['end']})"
316
- for fr in function_ranges]
317
- fn_map_hint = "\n\nFunction map:\n" + "\n".join(fn_lines)
318
-
319
- task_desc = obs.get("task_description", "")
320
- max_steps = obs.get("max_steps", 20)
321
- issue_categories = code_metadata.get("issue_categories", [])
322
- n_gt = len(obs.get("code_files", {})) # rough complexity hint
323
- category_hint = ""
324
- if issue_categories:
325
- category_hint = f"\nIssue categories to look for: {sorted(set(issue_categories))}"
326
-
327
- # RC-GRPO style reward conditioning (2025): tell the agent what quality level
328
- # it should aim for, so it calibrates confidence appropriately.
329
- state_features = code_metadata.get("state_features", [])
330
- complexity_label = "medium"
331
- if state_features and len(state_features) >= 4:
332
- complexity_score = state_features[3]
333
- complexity_label = "high" if complexity_score >= 1.0 else "medium" if complexity_score >= 0.5 else "low"
334
-
335
- reward_conditioning = (
336
- f"[TARGET: high-quality review, score ≥ 0.85. "
337
- f"Code complexity: {complexity_label}. "
338
- f"Be thorough — missing issues costs more than a single FP.]"
339
- )
 
 
 
 
 
 
 
 
 
 
 
 
340
 
341
- messages = [
342
- {"role": "system", "content": SYSTEM_PROMPT},
343
- {
344
- "role": "user",
345
- "content": (
346
- f"{reward_conditioning}\n\n"
347
- f"Task: {task_desc}\n\n"
348
- f"{code_display}"
349
- f"{fn_map_hint}"
350
- f"{category_hint}\n\n"
351
- f"You have {max_steps} steps total. "
352
- f"Work through the checklist systematically, function by function. "
353
- f"Flag each issue one at a time as a raw JSON object."
354
- ),
355
- },
356
- ]
357
-
358
- done = False
359
- last_action: dict = {}
360
- last_reward: Optional[float] = None
361
- consecutive_fp = 0
362
-
363
- while not done and step_count < max_steps:
364
- # --- Auto clear_flag recovery: undo recent FP if hurting precision ---
365
- recovery_action = _should_clear_flag(obs, last_reward, last_action)
366
- if recovery_action and step_count < max_steps - 1:
367
- action = recovery_action
368
- action_text = json.dumps(action)
369
- print(f" Auto-recovery: clearing FP at {action.get('filename')}:{action.get('line_number')}")
370
- else:
371
- # --- Normal LLM action ---
372
  try:
373
- action_text = chat_completion(messages)
 
 
374
  except Exception as e:
375
- print(f"[DEBUG] LLM unavailable ({e}) — submitting", flush=True)
376
- try:
377
- http_client.post(f"{ENV_URL}/step", json={"action_type": "submit_review"}, timeout=30)
378
- except Exception:
379
- pass
380
- log_end(success=False, steps=step_count, score=0.0, rewards=all_rewards)
381
- return {"task_id": task_id, "score": 0.0, "steps": step_count, "method": "error"}
382
-
383
- action = parse_action(action_text)
384
-
385
- # Smart submission: inject submit if progress shows we're done
386
- if action.get("action_type") != "submit_review" and _should_submit(obs, step_count, max_steps):
387
- print(f" Smart submit at step {step_count + 1} (recall target met)")
388
- action = {"action_type": "submit_review"}
389
- action_text = json.dumps(action)
390
 
391
- try:
392
- step_resp = http_client.post(f"{ENV_URL}/step", json=action, timeout=30)
393
- step_resp.raise_for_status()
394
- obs = step_resp.json()
395
- except Exception as e:
396
  step_count += 1
397
- log_step(step=step_count, action="error", reward=0.0, done=True, error=str(e))
398
- break
399
-
400
- done = obs.get("done", False)
401
- step_count += 1
402
- last_reward = obs.get("reward")
403
- # Use terminal reward (final grade) when done, else intermediate score
404
- if done:
405
- final_score = last_reward or obs.get("current_score", 0.0)
406
- else:
407
- final_score = obs.get("current_score", 0.0)
408
- last_action = action
409
-
410
- # Track consecutive FPs for logging
411
- if last_reward is not None and last_reward < 0 and action.get("action_type") == "flag_issue":
412
- consecutive_fp += 1
413
- else:
414
- consecutive_fp = 0
415
-
416
- # Build rich feedback for next LLM turn
417
- progress_feedback = _build_progress_feedback(obs)
418
- env_feedback = obs.get("feedback", "")
419
- combined_feedback = env_feedback
420
- if progress_feedback:
421
- combined_feedback += f"\n{progress_feedback}"
422
-
423
- messages.append({"role": "assistant", "content": action_text})
424
- if combined_feedback:
425
- messages.append({"role": "user", "content": combined_feedback})
426
-
427
- # Context window management: keep system + initial prompt + last 12 exchanges
428
- # This prevents token limit errors on long episodes (25+ steps)
429
- max_history = 2 + 24 # system + initial user + 12 assistant/user pairs
430
- if len(messages) > max_history:
431
- messages = messages[:2] + messages[-(max_history - 2):]
432
-
433
- atype = action.get("action_type", "")
434
- reward_val = float(last_reward) if last_reward is not None else 0.0
435
- all_rewards.append(reward_val)
436
- action_str = f"{atype}({action.get('filename', '')}:{action.get('line_number', '')})" if atype == "flag_issue" else atype
437
- log_step(
438
- step=step_count,
439
- action=action_str,
440
- reward=reward_val,
441
- done=done,
442
- error=None,
443
- )
444
 
445
- if atype == "submit_review":
446
- final_score = obs.get("reward", obs.get("current_score", 0.0)) or 0.0
447
- break
448
 
449
- time.sleep(0.3)
 
 
 
 
 
 
 
 
450
 
451
- log_end(
452
- success=final_score >= 0.5,
453
- steps=step_count,
454
- score=final_score,
455
- rewards=all_rewards,
456
- )
457
  return {
458
  "task_id": task_id,
459
  "score": float(final_score),
@@ -463,21 +449,14 @@ def run_task(task_id: str, http_client: httpx.Client) -> dict:
463
 
464
 
465
  def main():
466
- use_llm = bool(HF_TOKEN and API_BASE_URL)
467
-
468
- print("Code Review Environment — Inference")
469
- print(f" Model : {MODEL_NAME}")
470
- print(f" API URL : {API_BASE_URL or '(not set — using keyword heuristic)'}")
471
- print(f" Env URL : {ENV_URL}")
472
- print(f" Tasks : {TASK_IDS}\n")
473
 
474
  try:
475
  with httpx.Client(timeout=10) as probe:
476
  health = probe.get(f"{ENV_URL}/health")
477
  health.raise_for_status()
478
- print(f" Health: {health.json()}\n")
479
  except Exception as e:
480
- print(f"ERROR: Cannot reach environment at {ENV_URL}: {e}")
481
  sys.exit(1)
482
 
483
  results = {}
 
1
  """
2
+ Inference Script — Code Review Environment
3
+ ===========================================
4
+ MANDATORY environment variables:
5
+ API_BASE_URL The API endpoint for the LLM.
6
+ MODEL_NAME The model identifier to use for inference.
7
+ HF_TOKEN Your Hugging Face / API key.
8
 
9
+ Defaults are set only for API_BASE_URL and MODEL_NAME:
10
+ API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
11
+ MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
 
 
12
 
13
  Usage:
14
  export HF_TOKEN=hf_...
 
20
  import sys
21
  import json
22
  import time
23
+ from typing import List, Optional
24
 
25
  import httpx
26
+ from openai import OpenAI
27
 
28
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
29
+ API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
30
+ MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
31
  ENV_URL: str = os.getenv("ENV_URL", "http://localhost:7860").rstrip("/")
32
  BENCHMARK = "code-review-env"
33
 
 
49
  )
50
 
51
 
52
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
53
  rewards_str = ",".join(f"{r:.2f}" for r in rewards)
54
+ print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
55
 
56
  # Curriculum ordering: easy → medium → medium-hard → hard
57
  # Research (CAMRL, Curriculum RL): start with simpler tasks to build
 
119
  ## RULES
120
  - Raw JSON only — no markdown fences, no extra text
121
  - One action per response
122
+ - Line numbers are shown as "N|" prefix — use those EXACT numbers, do NOT count yourself
123
  - Only flag REAL issues β€” no style preferences, no hypothetical issues
124
  - Be precise: "SQL injection at line 19 via f-string in SELECT query" not just "SQL injection"
125
  - Flag the EXACT line where the problem code is (the f-string line, not the function def)
126
+ - issue_type MUST be: "security" for injection/XSS/hardcoded secrets/crypto/auth, "bug" for logic/off-by-one/wrong values, "performance" for N+1/missing gather/uncapped pagination
127
  """
128
 
129
 
130
  def chat_completion(messages: list) -> str:
131
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
 
 
132
  try:
133
  response = client.chat.completions.create(
134
  model=MODEL_NAME,
 
261
  return False
262
 
263
 
264
+ _cleared_lines: set = set() # track lines we've already cleared to prevent loops
265
+
266
+
267
  def _should_clear_flag(obs: dict, last_reward: float, last_action: dict) -> Optional[dict]:
268
  """
269
  Recovery strategy: if the last flag was a false positive with high penalty,
270
+ suggest clearing it. Only clears each line ONCE to prevent flag/clear loops.
 
 
271
  """
272
  if last_reward is None or last_reward >= 0:
273
  return None
274
  if last_action.get("action_type") != "flag_issue":
275
  return None
276
 
277
+ # Prevent loop: never clear the same line twice
278
+ line_key = (last_action.get("filename"), last_action.get("line_number"))
279
+ if line_key in _cleared_lines:
280
+ return None
281
+
282
  progress = obs.get("progress", {})
283
  fp = int(progress.get("false_positives", 0))
284
  tp = int(progress.get("true_positives", 0))
285
 
 
286
  if fp > tp and last_reward <= -0.05:
287
+ _cleared_lines.add(line_key)
288
  return {
289
  "action_type": "clear_flag",
290
  "line_number": last_action.get("line_number"),
 
296
 
297
  def run_task(task_id: str, http_client: httpx.Client) -> dict:
298
  log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
299
+ _cleared_lines.clear() # reset per-task
300
+ all_rewards: List[float] = []
301
  step_count = 0
302
  final_score = 0.0
303
 
 
305
  resp = http_client.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=30)
306
  resp.raise_for_status()
307
  obs = resp.json()
 
 
 
 
308
 
309
+ # Show code WITH line numbers — critical for LLM line-counting accuracy
310
+ code_parts = []
311
+ for fname, code in obs.get("code_files", {}).items():
312
+ numbered_lines = "\n".join(
313
+ f"{i+1:3d}| {line}" for i, line in enumerate(code.splitlines())
314
+ )
315
+ code_parts.append(f"=== {fname} ===\n{numbered_lines}")
316
+ code_display = "\n\n".join(code_parts)
317
+
318
+ code_metadata = obs.get("code_metadata") or {}
319
+ function_ranges = code_metadata.get("function_ranges", [])
320
+ fn_map_hint = ""
321
+ if function_ranges:
322
+ fn_lines = [f" {fr['name']}() in {fr['file']} (lines {fr['start']}-{fr['end']})"
323
+ for fr in function_ranges]
324
+ fn_map_hint = "\n\nFunction map:\n" + "\n".join(fn_lines)
325
+
326
+ task_desc = obs.get("task_description", "")
327
+ max_steps = obs.get("max_steps", 20)
328
+ issue_categories = code_metadata.get("issue_categories", [])
329
+ category_hint = ""
330
+ if issue_categories:
331
+ category_hint = f"\nIssue categories to look for: {sorted(set(issue_categories))}"
332
+
333
+ state_features = code_metadata.get("state_features", [])
334
+ complexity_label = "medium"
335
+ if state_features and len(state_features) >= 4:
336
+ complexity_score = state_features[3]
337
+ complexity_label = "high" if complexity_score >= 1.0 else "medium" if complexity_score >= 0.5 else "low"
338
+
339
+ reward_conditioning = (
340
+ f"[TARGET: high-quality review, score ≥ 0.85. "
341
+ f"Code complexity: {complexity_label}. "
342
+ f"Be thorough — missing issues costs more than a single FP.]"
343
+ )
344
 
345
+ messages = [
346
+ {"role": "system", "content": SYSTEM_PROMPT},
347
+ {
348
+ "role": "user",
349
+ "content": (
350
+ f"{reward_conditioning}\n\n"
351
+ f"Task: {task_desc}\n\n"
352
+ f"{code_display}"
353
+ f"{fn_map_hint}"
354
+ f"{category_hint}\n\n"
355
+ f"You have {max_steps} steps total. "
356
+ f"Work through the checklist systematically, function by function. "
357
+ f"Flag each issue one at a time as a raw JSON object."
358
+ ),
359
+ },
360
+ ]
361
+
362
+ done = False
363
+ last_action: dict = {}
364
+ last_reward: Optional[float] = None
365
+
366
+ while not done and step_count < max_steps:
367
+ recovery_action = _should_clear_flag(obs, last_reward, last_action)
368
+ if recovery_action and step_count < max_steps - 1:
369
+ action = recovery_action
370
+ action_text = json.dumps(action)
371
+ else:
372
+ try:
373
+ action_text = chat_completion(messages)
374
+ except Exception as e:
375
+ print(f"[DEBUG] LLM unavailable ({e})", flush=True)
376
+ try:
377
+ http_client.post(f"{ENV_URL}/step", json={"action_type": "submit_review"}, timeout=30)
378
+ except Exception:
379
+ pass
380
+ break
381
+
382
+ action = parse_action(action_text)
383
+
384
+ if action.get("action_type") != "submit_review" and _should_submit(obs, step_count, max_steps):
385
+ action = {"action_type": "submit_review"}
386
+ action_text = json.dumps(action)
387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
  try:
389
+ step_resp = http_client.post(f"{ENV_URL}/step", json=action, timeout=30)
390
+ step_resp.raise_for_status()
391
+ obs = step_resp.json()
392
  except Exception as e:
393
+ step_count += 1
394
+ log_step(step=step_count, action="error", reward=0.0, done=True, error=str(e))
395
+ break
 
 
 
 
 
 
 
 
 
 
 
 
396
 
397
+ done = obs.get("done", False)
 
 
 
 
398
  step_count += 1
399
+ last_reward = obs.get("reward")
400
+ if done:
401
+ final_score = last_reward or obs.get("current_score", 0.0)
402
+ else:
403
+ final_score = obs.get("current_score", 0.0)
404
+ last_action = action
405
+
406
+ # Build feedback for next LLM turn
407
+ progress_feedback = _build_progress_feedback(obs)
408
+ env_feedback = obs.get("feedback", "")
409
+ combined_feedback = env_feedback
410
+ if progress_feedback:
411
+ combined_feedback += f"\n{progress_feedback}"
412
+
413
+ messages.append({"role": "assistant", "content": action_text})
414
+ if combined_feedback:
415
+ messages.append({"role": "user", "content": combined_feedback})
416
+
417
+ max_history = 2 + 24
418
+ if len(messages) > max_history:
419
+ messages = messages[:2] + messages[-(max_history - 2):]
420
+
421
+ atype = action.get("action_type", "")
422
+ reward_val = float(last_reward) if last_reward is not None else 0.0
423
+ all_rewards.append(reward_val)
424
+ action_str = f"{atype}({action.get('filename', '')}:{action.get('line_number', '')})" if atype == "flag_issue" else atype
425
+ log_step(step=step_count, action=action_str, reward=reward_val, done=done, error=None)
426
+
427
+ if atype == "submit_review":
428
+ final_score = obs.get("reward", obs.get("current_score", 0.0)) or 0.0
429
+ break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
 
431
+ time.sleep(0.3)
 
 
432
 
433
+ except Exception as e:
434
+ print(f"[DEBUG] Task {task_id} error: {e}", flush=True)
435
+ finally:
436
+ log_end(
437
+ success=final_score >= 0.5,
438
+ steps=step_count,
439
+ score=final_score,
440
+ rewards=all_rewards,
441
+ )
442
 
 
 
 
 
 
 
443
  return {
444
  "task_id": task_id,
445
  "score": float(final_score),
 
449
 
450
 
451
  def main():
452
+ use_llm = bool(API_KEY and API_BASE_URL)
 
 
 
 
 
 
453
 
454
  try:
455
  with httpx.Client(timeout=10) as probe:
456
  health = probe.get(f"{ENV_URL}/health")
457
  health.raise_for_status()
 
458
  except Exception as e:
459
+ print(f"[DEBUG] Cannot reach environment at {ENV_URL}: {e}", flush=True)
460
  sys.exit(1)
461
 
462
  results = {}