Spaces:
Sleeping
Sleeping
codemaverick2 committed on
Commit ·
6535d0b
1
Parent(s): d3e536b
Add line numbers to code display, fix clear_flag loop, match submission spec exactly
Browse files- inference.py +158 -179
inference.py
CHANGED
|
@@ -1,11 +1,14 @@
|
|
| 1 |
"""
|
| 2 |
-
Inference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
-
|
| 5 |
-
API_BASE_URL
|
| 6 |
-
MODEL_NAME
|
| 7 |
-
HF_TOKEN — Your HuggingFace / API key (no default — must be set)
|
| 8 |
-
ENV_URL — Environment base URL (default: http://localhost:7860)
|
| 9 |
|
| 10 |
Usage:
|
| 11 |
export HF_TOKEN=hf_...
|
|
@@ -17,13 +20,14 @@ import os
|
|
| 17 |
import sys
|
| 18 |
import json
|
| 19 |
import time
|
| 20 |
-
from typing import Optional
|
| 21 |
|
| 22 |
import httpx
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
ENV_URL: str = os.getenv("ENV_URL", "http://localhost:7860").rstrip("/")
|
| 28 |
BENCHMARK = "code-review-env"
|
| 29 |
|
|
@@ -45,9 +49,9 @@ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[
|
|
| 45 |
)
|
| 46 |
|
| 47 |
|
| 48 |
-
def log_end(success: bool, steps: int, score: float, rewards:
|
| 49 |
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 50 |
-
print(f"[END] success={str(success).lower()} steps={steps} score={score:.
|
| 51 |
|
| 52 |
# Curriculum ordering: easy → medium → medium-hard → hard
|
| 53 |
# Research (CAMRL, Curriculum RL): start with simpler tasks to build
|
|
@@ -115,17 +119,16 @@ When finished, respond with:
|
|
| 115 |
## RULES
|
| 116 |
- Raw JSON only β no markdown fences, no extra text
|
| 117 |
- One action per response
|
| 118 |
-
-
|
| 119 |
- Only flag REAL issues β no style preferences, no hypothetical issues
|
| 120 |
- Be precise: "SQL injection at line 19 via f-string in SELECT query" not just "SQL injection"
|
| 121 |
- Flag the EXACT line where the problem code is (the f-string line, not the function def)
|
|
|
|
| 122 |
"""
|
| 123 |
|
| 124 |
|
| 125 |
def chat_completion(messages: list) -> str:
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 129 |
try:
|
| 130 |
response = client.chat.completions.create(
|
| 131 |
model=MODEL_NAME,
|
|
@@ -258,26 +261,30 @@ def _should_submit(obs: dict, step_count: int, max_steps: int) -> bool:
|
|
| 258 |
return False
|
| 259 |
|
| 260 |
|
|
|
|
|
|
|
|
|
|
| 261 |
def _should_clear_flag(obs: dict, last_reward: float, last_action: dict) -> Optional[dict]:
|
| 262 |
"""
|
| 263 |
Recovery strategy: if the last flag was a false positive with high penalty,
|
| 264 |
-
suggest clearing it
|
| 265 |
-
|
| 266 |
-
Returns a clear_flag action dict if we should recover, else None.
|
| 267 |
"""
|
| 268 |
if last_reward is None or last_reward >= 0:
|
| 269 |
return None
|
| 270 |
if last_action.get("action_type") != "flag_issue":
|
| 271 |
return None
|
| 272 |
|
| 273 |
-
#
|
| 274 |
-
|
|
|
|
|
|
|
|
|
|
| 275 |
progress = obs.get("progress", {})
|
| 276 |
fp = int(progress.get("false_positives", 0))
|
| 277 |
tp = int(progress.get("true_positives", 0))
|
| 278 |
|
| 279 |
-
# If FP > TP and last reward was notably negative, clear the bad flag
|
| 280 |
if fp > tp and last_reward <= -0.05:
|
|
|
|
| 281 |
return {
|
| 282 |
"action_type": "clear_flag",
|
| 283 |
"line_number": last_action.get("line_number"),
|
|
@@ -289,7 +296,8 @@ def _should_clear_flag(obs: dict, last_reward: float, last_action: dict) -> Opti
|
|
| 289 |
|
| 290 |
def run_task(task_id: str, http_client: httpx.Client) -> dict:
|
| 291 |
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
|
| 292 |
-
|
|
|
|
| 293 |
step_count = 0
|
| 294 |
final_score = 0.0
|
| 295 |
|
|
@@ -297,163 +305,141 @@ def run_task(task_id: str, http_client: httpx.Client) -> dict:
|
|
| 297 |
resp = http_client.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=30)
|
| 298 |
resp.raise_for_status()
|
| 299 |
obs = resp.json()
|
| 300 |
-
except Exception as e:
|
| 301 |
-
print(f"[DEBUG] Reset failed: {e}", flush=True)
|
| 302 |
-
log_end(success=False, steps=0, score=0.0, rewards=[])
|
| 303 |
-
return run_keyword_fallback(ENV_URL, task_id)
|
| 304 |
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
for fname, code in obs.get("code_files", {}).items()
|
| 308 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
|
| 341 |
-
messages = [
|
| 342 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 343 |
-
{
|
| 344 |
-
"role": "user",
|
| 345 |
-
"content": (
|
| 346 |
-
f"{reward_conditioning}\n\n"
|
| 347 |
-
f"Task: {task_desc}\n\n"
|
| 348 |
-
f"{code_display}"
|
| 349 |
-
f"{fn_map_hint}"
|
| 350 |
-
f"{category_hint}\n\n"
|
| 351 |
-
f"You have {max_steps} steps total. "
|
| 352 |
-
f"Work through the checklist systematically, function by function. "
|
| 353 |
-
f"Flag each issue one at a time as a raw JSON object."
|
| 354 |
-
),
|
| 355 |
-
},
|
| 356 |
-
]
|
| 357 |
-
|
| 358 |
-
done = False
|
| 359 |
-
last_action: dict = {}
|
| 360 |
-
last_reward: Optional[float] = None
|
| 361 |
-
consecutive_fp = 0
|
| 362 |
-
|
| 363 |
-
while not done and step_count < max_steps:
|
| 364 |
-
# --- Auto clear_flag recovery: undo recent FP if hurting precision ---
|
| 365 |
-
recovery_action = _should_clear_flag(obs, last_reward, last_action)
|
| 366 |
-
if recovery_action and step_count < max_steps - 1:
|
| 367 |
-
action = recovery_action
|
| 368 |
-
action_text = json.dumps(action)
|
| 369 |
-
print(f" Auto-recovery: clearing FP at {action.get('filename')}:{action.get('line_number')}")
|
| 370 |
-
else:
|
| 371 |
-
# --- Normal LLM action ---
|
| 372 |
try:
|
| 373 |
-
|
|
|
|
|
|
|
| 374 |
except Exception as e:
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
except Exception:
|
| 379 |
-
pass
|
| 380 |
-
log_end(success=False, steps=step_count, score=0.0, rewards=all_rewards)
|
| 381 |
-
return {"task_id": task_id, "score": 0.0, "steps": step_count, "method": "error"}
|
| 382 |
-
|
| 383 |
-
action = parse_action(action_text)
|
| 384 |
-
|
| 385 |
-
# Smart submission: inject submit if progress shows we're done
|
| 386 |
-
if action.get("action_type") != "submit_review" and _should_submit(obs, step_count, max_steps):
|
| 387 |
-
print(f" Smart submit at step {step_count + 1} (recall target met)")
|
| 388 |
-
action = {"action_type": "submit_review"}
|
| 389 |
-
action_text = json.dumps(action)
|
| 390 |
|
| 391 |
-
|
| 392 |
-
step_resp = http_client.post(f"{ENV_URL}/step", json=action, timeout=30)
|
| 393 |
-
step_resp.raise_for_status()
|
| 394 |
-
obs = step_resp.json()
|
| 395 |
-
except Exception as e:
|
| 396 |
step_count += 1
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
# This prevents token limit errors on long episodes (25+ steps)
|
| 429 |
-
max_history = 2 + 24 # system + initial user + 12 assistant/user pairs
|
| 430 |
-
if len(messages) > max_history:
|
| 431 |
-
messages = messages[:2] + messages[-(max_history - 2):]
|
| 432 |
-
|
| 433 |
-
atype = action.get("action_type", "")
|
| 434 |
-
reward_val = float(last_reward) if last_reward is not None else 0.0
|
| 435 |
-
all_rewards.append(reward_val)
|
| 436 |
-
action_str = f"{atype}({action.get('filename', '')}:{action.get('line_number', '')})" if atype == "flag_issue" else atype
|
| 437 |
-
log_step(
|
| 438 |
-
step=step_count,
|
| 439 |
-
action=action_str,
|
| 440 |
-
reward=reward_val,
|
| 441 |
-
done=done,
|
| 442 |
-
error=None,
|
| 443 |
-
)
|
| 444 |
|
| 445 |
-
|
| 446 |
-
final_score = obs.get("reward", obs.get("current_score", 0.0)) or 0.0
|
| 447 |
-
break
|
| 448 |
|
| 449 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
|
| 451 |
-
log_end(
|
| 452 |
-
success=final_score >= 0.5,
|
| 453 |
-
steps=step_count,
|
| 454 |
-
score=final_score,
|
| 455 |
-
rewards=all_rewards,
|
| 456 |
-
)
|
| 457 |
return {
|
| 458 |
"task_id": task_id,
|
| 459 |
"score": float(final_score),
|
|
@@ -463,21 +449,14 @@ def run_task(task_id: str, http_client: httpx.Client) -> dict:
|
|
| 463 |
|
| 464 |
|
| 465 |
def main():
|
| 466 |
-
use_llm = bool(
|
| 467 |
-
|
| 468 |
-
print("Code Review Environment β Inference")
|
| 469 |
-
print(f" Model : {MODEL_NAME}")
|
| 470 |
-
print(f" API URL : {API_BASE_URL or '(not set β using keyword heuristic)'}")
|
| 471 |
-
print(f" Env URL : {ENV_URL}")
|
| 472 |
-
print(f" Tasks : {TASK_IDS}\n")
|
| 473 |
|
| 474 |
try:
|
| 475 |
with httpx.Client(timeout=10) as probe:
|
| 476 |
health = probe.get(f"{ENV_URL}/health")
|
| 477 |
health.raise_for_status()
|
| 478 |
-
print(f" Health: {health.json()}\n")
|
| 479 |
except Exception as e:
|
| 480 |
-
print(f"
|
| 481 |
sys.exit(1)
|
| 482 |
|
| 483 |
results = {}
|
|
|
|
| 1 |
"""
|
| 2 |
+
Inference Script — Code Review Environment
|
| 3 |
+
===========================================
|
| 4 |
+
MANDATORY environment variables:
|
| 5 |
+
API_BASE_URL The API endpoint for the LLM.
|
| 6 |
+
MODEL_NAME The model identifier to use for inference.
|
| 7 |
+
HF_TOKEN Your Hugging Face / API key.
|
| 8 |
|
| 9 |
+
Defaults are set only for API_BASE_URL and MODEL_NAME:
|
| 10 |
+
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
|
| 11 |
+
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
|
|
|
|
|
|
|
| 12 |
|
| 13 |
Usage:
|
| 14 |
export HF_TOKEN=hf_...
|
|
|
|
| 20 |
import sys
|
| 21 |
import json
|
| 22 |
import time
|
| 23 |
+
from typing import List, Optional
|
| 24 |
|
| 25 |
import httpx
|
| 26 |
+
from openai import OpenAI
|
| 27 |
|
| 28 |
+
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 29 |
+
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
|
| 30 |
+
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
|
| 31 |
ENV_URL: str = os.getenv("ENV_URL", "http://localhost:7860").rstrip("/")
|
| 32 |
BENCHMARK = "code-review-env"
|
| 33 |
|
|
|
|
| 49 |
)
|
| 50 |
|
| 51 |
|
| 52 |
+
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 53 |
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 54 |
+
print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
|
| 55 |
|
| 56 |
# Curriculum ordering: easy → medium → medium-hard → hard
|
| 57 |
# Research (CAMRL, Curriculum RL): start with simpler tasks to build
|
|
|
|
| 119 |
## RULES
|
| 120 |
- Raw JSON only β no markdown fences, no extra text
|
| 121 |
- One action per response
|
| 122 |
+
- Line numbers are shown as "N|" prefix β use those EXACT numbers, do NOT count yourself
|
| 123 |
- Only flag REAL issues β no style preferences, no hypothetical issues
|
| 124 |
- Be precise: "SQL injection at line 19 via f-string in SELECT query" not just "SQL injection"
|
| 125 |
- Flag the EXACT line where the problem code is (the f-string line, not the function def)
|
| 126 |
+
- issue_type MUST be: "security" for injection/XSS/hardcoded secrets/crypto/auth, "bug" for logic/off-by-one/wrong values, "performance" for N+1/missing gather/uncapped pagination
|
| 127 |
"""
|
| 128 |
|
| 129 |
|
| 130 |
def chat_completion(messages: list) -> str:
|
| 131 |
+
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
|
|
|
|
|
|
| 132 |
try:
|
| 133 |
response = client.chat.completions.create(
|
| 134 |
model=MODEL_NAME,
|
|
|
|
| 261 |
return False
|
| 262 |
|
| 263 |
|
| 264 |
+
_cleared_lines: set = set() # track lines we've already cleared to prevent loops
|
| 265 |
+
|
| 266 |
+
|
| 267 |
def _should_clear_flag(obs: dict, last_reward: float, last_action: dict) -> Optional[dict]:
|
| 268 |
"""
|
| 269 |
Recovery strategy: if the last flag was a false positive with high penalty,
|
| 270 |
+
suggest clearing it. Only clears each line ONCE to prevent flag/clear loops.
|
|
|
|
|
|
|
| 271 |
"""
|
| 272 |
if last_reward is None or last_reward >= 0:
|
| 273 |
return None
|
| 274 |
if last_action.get("action_type") != "flag_issue":
|
| 275 |
return None
|
| 276 |
|
| 277 |
+
# Prevent loop: never clear the same line twice
|
| 278 |
+
line_key = (last_action.get("filename"), last_action.get("line_number"))
|
| 279 |
+
if line_key in _cleared_lines:
|
| 280 |
+
return None
|
| 281 |
+
|
| 282 |
progress = obs.get("progress", {})
|
| 283 |
fp = int(progress.get("false_positives", 0))
|
| 284 |
tp = int(progress.get("true_positives", 0))
|
| 285 |
|
|
|
|
| 286 |
if fp > tp and last_reward <= -0.05:
|
| 287 |
+
_cleared_lines.add(line_key)
|
| 288 |
return {
|
| 289 |
"action_type": "clear_flag",
|
| 290 |
"line_number": last_action.get("line_number"),
|
|
|
|
| 296 |
|
| 297 |
def run_task(task_id: str, http_client: httpx.Client) -> dict:
|
| 298 |
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
|
| 299 |
+
_cleared_lines.clear() # reset per-task
|
| 300 |
+
all_rewards: List[float] = []
|
| 301 |
step_count = 0
|
| 302 |
final_score = 0.0
|
| 303 |
|
|
|
|
| 305 |
resp = http_client.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=30)
|
| 306 |
resp.raise_for_status()
|
| 307 |
obs = resp.json()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
|
| 309 |
+
# Show code WITH line numbers — critical for LLM line-counting accuracy
|
| 310 |
+
code_parts = []
|
| 311 |
+
for fname, code in obs.get("code_files", {}).items():
|
| 312 |
+
numbered_lines = "\n".join(
|
| 313 |
+
f"{i+1:3d}| {line}" for i, line in enumerate(code.splitlines())
|
| 314 |
+
)
|
| 315 |
+
code_parts.append(f"=== {fname} ===\n{numbered_lines}")
|
| 316 |
+
code_display = "\n\n".join(code_parts)
|
| 317 |
+
|
| 318 |
+
code_metadata = obs.get("code_metadata") or {}
|
| 319 |
+
function_ranges = code_metadata.get("function_ranges", [])
|
| 320 |
+
fn_map_hint = ""
|
| 321 |
+
if function_ranges:
|
| 322 |
+
fn_lines = [f" {fr['name']}() in {fr['file']} (lines {fr['start']}-{fr['end']})"
|
| 323 |
+
for fr in function_ranges]
|
| 324 |
+
fn_map_hint = "\n\nFunction map:\n" + "\n".join(fn_lines)
|
| 325 |
+
|
| 326 |
+
task_desc = obs.get("task_description", "")
|
| 327 |
+
max_steps = obs.get("max_steps", 20)
|
| 328 |
+
issue_categories = code_metadata.get("issue_categories", [])
|
| 329 |
+
category_hint = ""
|
| 330 |
+
if issue_categories:
|
| 331 |
+
category_hint = f"\nIssue categories to look for: {sorted(set(issue_categories))}"
|
| 332 |
+
|
| 333 |
+
state_features = code_metadata.get("state_features", [])
|
| 334 |
+
complexity_label = "medium"
|
| 335 |
+
if state_features and len(state_features) >= 4:
|
| 336 |
+
complexity_score = state_features[3]
|
| 337 |
+
complexity_label = "high" if complexity_score >= 1.0 else "medium" if complexity_score >= 0.5 else "low"
|
| 338 |
+
|
| 339 |
+
reward_conditioning = (
|
| 340 |
+
f"[TARGET: high-quality review, score β₯ 0.85. "
|
| 341 |
+
f"Code complexity: {complexity_label}. "
|
| 342 |
+
f"Be thorough β missing issues costs more than a single FP.]"
|
| 343 |
+
)
|
| 344 |
|
| 345 |
+
messages = [
|
| 346 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 347 |
+
{
|
| 348 |
+
"role": "user",
|
| 349 |
+
"content": (
|
| 350 |
+
f"{reward_conditioning}\n\n"
|
| 351 |
+
f"Task: {task_desc}\n\n"
|
| 352 |
+
f"{code_display}"
|
| 353 |
+
f"{fn_map_hint}"
|
| 354 |
+
f"{category_hint}\n\n"
|
| 355 |
+
f"You have {max_steps} steps total. "
|
| 356 |
+
f"Work through the checklist systematically, function by function. "
|
| 357 |
+
f"Flag each issue one at a time as a raw JSON object."
|
| 358 |
+
),
|
| 359 |
+
},
|
| 360 |
+
]
|
| 361 |
+
|
| 362 |
+
done = False
|
| 363 |
+
last_action: dict = {}
|
| 364 |
+
last_reward: Optional[float] = None
|
| 365 |
+
|
| 366 |
+
while not done and step_count < max_steps:
|
| 367 |
+
recovery_action = _should_clear_flag(obs, last_reward, last_action)
|
| 368 |
+
if recovery_action and step_count < max_steps - 1:
|
| 369 |
+
action = recovery_action
|
| 370 |
+
action_text = json.dumps(action)
|
| 371 |
+
else:
|
| 372 |
+
try:
|
| 373 |
+
action_text = chat_completion(messages)
|
| 374 |
+
except Exception as e:
|
| 375 |
+
print(f"[DEBUG] LLM unavailable ({e})", flush=True)
|
| 376 |
+
try:
|
| 377 |
+
http_client.post(f"{ENV_URL}/step", json={"action_type": "submit_review"}, timeout=30)
|
| 378 |
+
except Exception:
|
| 379 |
+
pass
|
| 380 |
+
break
|
| 381 |
+
|
| 382 |
+
action = parse_action(action_text)
|
| 383 |
+
|
| 384 |
+
if action.get("action_type") != "submit_review" and _should_submit(obs, step_count, max_steps):
|
| 385 |
+
action = {"action_type": "submit_review"}
|
| 386 |
+
action_text = json.dumps(action)
|
| 387 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
try:
|
| 389 |
+
step_resp = http_client.post(f"{ENV_URL}/step", json=action, timeout=30)
|
| 390 |
+
step_resp.raise_for_status()
|
| 391 |
+
obs = step_resp.json()
|
| 392 |
except Exception as e:
|
| 393 |
+
step_count += 1
|
| 394 |
+
log_step(step=step_count, action="error", reward=0.0, done=True, error=str(e))
|
| 395 |
+
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
|
| 397 |
+
done = obs.get("done", False)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
step_count += 1
|
| 399 |
+
last_reward = obs.get("reward")
|
| 400 |
+
if done:
|
| 401 |
+
final_score = last_reward or obs.get("current_score", 0.0)
|
| 402 |
+
else:
|
| 403 |
+
final_score = obs.get("current_score", 0.0)
|
| 404 |
+
last_action = action
|
| 405 |
+
|
| 406 |
+
# Build feedback for next LLM turn
|
| 407 |
+
progress_feedback = _build_progress_feedback(obs)
|
| 408 |
+
env_feedback = obs.get("feedback", "")
|
| 409 |
+
combined_feedback = env_feedback
|
| 410 |
+
if progress_feedback:
|
| 411 |
+
combined_feedback += f"\n{progress_feedback}"
|
| 412 |
+
|
| 413 |
+
messages.append({"role": "assistant", "content": action_text})
|
| 414 |
+
if combined_feedback:
|
| 415 |
+
messages.append({"role": "user", "content": combined_feedback})
|
| 416 |
+
|
| 417 |
+
max_history = 2 + 24
|
| 418 |
+
if len(messages) > max_history:
|
| 419 |
+
messages = messages[:2] + messages[-(max_history - 2):]
|
| 420 |
+
|
| 421 |
+
atype = action.get("action_type", "")
|
| 422 |
+
reward_val = float(last_reward) if last_reward is not None else 0.0
|
| 423 |
+
all_rewards.append(reward_val)
|
| 424 |
+
action_str = f"{atype}({action.get('filename', '')}:{action.get('line_number', '')})" if atype == "flag_issue" else atype
|
| 425 |
+
log_step(step=step_count, action=action_str, reward=reward_val, done=done, error=None)
|
| 426 |
+
|
| 427 |
+
if atype == "submit_review":
|
| 428 |
+
final_score = obs.get("reward", obs.get("current_score", 0.0)) or 0.0
|
| 429 |
+
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
|
| 431 |
+
time.sleep(0.3)
|
|
|
|
|
|
|
| 432 |
|
| 433 |
+
except Exception as e:
|
| 434 |
+
print(f"[DEBUG] Task {task_id} error: {e}", flush=True)
|
| 435 |
+
finally:
|
| 436 |
+
log_end(
|
| 437 |
+
success=final_score >= 0.5,
|
| 438 |
+
steps=step_count,
|
| 439 |
+
score=final_score,
|
| 440 |
+
rewards=all_rewards,
|
| 441 |
+
)
|
| 442 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
return {
|
| 444 |
"task_id": task_id,
|
| 445 |
"score": float(final_score),
|
|
|
|
| 449 |
|
| 450 |
|
| 451 |
def main():
|
| 452 |
+
use_llm = bool(API_KEY and API_BASE_URL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 453 |
|
| 454 |
try:
|
| 455 |
with httpx.Client(timeout=10) as probe:
|
| 456 |
health = probe.get(f"{ENV_URL}/health")
|
| 457 |
health.raise_for_status()
|
|
|
|
| 458 |
except Exception as e:
|
| 459 |
+
print(f"[DEBUG] Cannot reach environment at {ENV_URL}: {e}", flush=True)
|
| 460 |
sys.exit(1)
|
| 461 |
|
| 462 |
results = {}
|