Spaces:
Sleeping
Sleeping
commit
Browse files- inference.py +27 -44
inference.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
inference.py — Baseline inference script for CodeReview-Env.
|
| 3 |
-
|
| 4 |
-
Mandatory [START] / [STEP] / [END] log format for OpenEnv evaluators.
|
| 5 |
|
| 6 |
Environment variables:
|
| 7 |
API_BASE_URL LLM API base URL
|
|
@@ -31,38 +30,26 @@ SUCCESS_SCORE_THRESHOLD = 0.6
|
|
| 31 |
TASKS = ["easy_syntax", "medium_logic", "hard_security"]
|
| 32 |
|
| 33 |
|
| 34 |
-
# ──
|
| 35 |
|
| 36 |
def log_start(task: str, env: str, model: str) -> None:
|
| 37 |
-
print(
|
| 38 |
|
| 39 |
|
| 40 |
def log_step(step: int, action: Any, reward: float, done: bool, error: Optional[str] = None) -> None:
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
"action": str(action)[:300],
|
| 45 |
-
"reward": reward,
|
| 46 |
-
"done": done,
|
| 47 |
-
"error": error,
|
| 48 |
-
}), flush=True)
|
| 49 |
|
| 50 |
|
| 51 |
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
"success": success,
|
| 55 |
-
"steps": steps,
|
| 56 |
-
"score": score,
|
| 57 |
-
"rewards": rewards,
|
| 58 |
-
}), flush=True)
|
| 59 |
|
| 60 |
|
| 61 |
# ── HTTP client for OpenEnv server ───────────────────────────
|
| 62 |
|
| 63 |
class EnvClient:
|
| 64 |
-
"""Thin HTTP client for the OpenEnv-compliant Space API."""
|
| 65 |
-
|
| 66 |
def __init__(self, base_url: str) -> None:
|
| 67 |
self.base_url = base_url
|
| 68 |
self._http = httpx.Client(timeout=60.0)
|
|
@@ -72,7 +59,7 @@ class EnvClient:
|
|
| 72 |
try:
|
| 73 |
r = self._http.get(f"{self.base_url}/health")
|
| 74 |
if r.status_code == 200:
|
| 75 |
-
print(f"[DEBUG] Server ready
|
| 76 |
return True
|
| 77 |
except Exception as e:
|
| 78 |
print(f"[DEBUG] Waiting for server ({i+1}/{retries}): {e}", flush=True)
|
|
@@ -80,7 +67,6 @@ class EnvClient:
|
|
| 80 |
return False
|
| 81 |
|
| 82 |
def reset(self) -> Dict:
|
| 83 |
-
"""POST /reset → returns {observation, reward, done}"""
|
| 84 |
try:
|
| 85 |
r = self._http.post(f"{self.base_url}/reset")
|
| 86 |
r.raise_for_status()
|
|
@@ -90,11 +76,9 @@ class EnvClient:
|
|
| 90 |
return {"observation": {}, "reward": 0.0, "done": False}
|
| 91 |
|
| 92 |
def step(self, action: Dict) -> Dict:
|
| 93 |
-
"""POST /step with {action: {...}} wrapper — OpenEnv format."""
|
| 94 |
try:
|
| 95 |
-
# OpenEnv create_app requires
|
| 96 |
-
|
| 97 |
-
r = self._http.post(f"{self.base_url}/step", json=payload)
|
| 98 |
r.raise_for_status()
|
| 99 |
return r.json()
|
| 100 |
except Exception as e:
|
|
@@ -112,9 +96,8 @@ class EnvClient:
|
|
| 112 |
|
| 113 |
SYSTEM_PROMPT = """\
|
| 114 |
You are an expert software engineer specialising in code review, debugging, \
|
| 115 |
-
and security auditing.
|
| 116 |
-
|
| 117 |
-
Return ONLY a JSON object in this exact format (no prose, no markdown fences):
|
| 118 |
|
| 119 |
{
|
| 120 |
"identified_issues": [
|
|
@@ -126,7 +109,7 @@ Return ONLY a JSON object in this exact format (no prose, no markdown fences):
|
|
| 126 |
}
|
| 127 |
],
|
| 128 |
"suggested_fix": "<complete corrected code as string, or null>",
|
| 129 |
-
"explanation": "<brief summary of findings>",
|
| 130 |
"submit": true
|
| 131 |
}
|
| 132 |
"""
|
|
@@ -173,7 +156,6 @@ def call_llm(llm_client: OpenAI, prompt: str) -> str:
|
|
| 173 |
|
| 174 |
def parse_llm_output(raw: str) -> Dict:
|
| 175 |
raw = raw.strip()
|
| 176 |
-
# Strip markdown fences if present
|
| 177 |
if raw.startswith("```"):
|
| 178 |
parts = raw.split("```")
|
| 179 |
raw = parts[1] if len(parts) > 1 else raw
|
|
@@ -185,7 +167,7 @@ def parse_llm_output(raw: str) -> Dict:
|
|
| 185 |
return {
|
| 186 |
"identified_issues": [],
|
| 187 |
"suggested_fix": None,
|
| 188 |
-
"explanation": raw[:
|
| 189 |
"submit": True,
|
| 190 |
}
|
| 191 |
|
|
@@ -204,11 +186,10 @@ def run_task(task_id: str, env: EnvClient, llm: OpenAI) -> float:
|
|
| 204 |
try:
|
| 205 |
result = env.reset()
|
| 206 |
obs = result.get("observation", {})
|
| 207 |
-
max_steps = obs.get("max_steps", 5)
|
| 208 |
|
| 209 |
for step in range(1, max_steps + 1):
|
| 210 |
-
|
| 211 |
-
if done:
|
| 212 |
break
|
| 213 |
|
| 214 |
prompt = build_prompt(obs, step, prev_feedback)
|
|
@@ -226,7 +207,7 @@ def run_task(task_id: str, env: EnvClient, llm: OpenAI) -> float:
|
|
| 226 |
|
| 227 |
log_step(
|
| 228 |
step=step,
|
| 229 |
-
action=action.get("explanation", "")[:
|
| 230 |
reward=reward,
|
| 231 |
done=done,
|
| 232 |
error=None,
|
|
@@ -241,6 +222,7 @@ def run_task(task_id: str, env: EnvClient, llm: OpenAI) -> float:
|
|
| 241 |
|
| 242 |
except Exception as e:
|
| 243 |
print(f"[DEBUG] run_task error: {e}", flush=True)
|
|
|
|
| 244 |
score = 0.0
|
| 245 |
success = False
|
| 246 |
|
|
@@ -253,34 +235,35 @@ def run_task(task_id: str, env: EnvClient, llm: OpenAI) -> float:
|
|
| 253 |
# ── Main ──────────────────────────────────────────────────────
|
| 254 |
|
| 255 |
def main() -> None:
|
| 256 |
-
print(f"[DEBUG] Starting
|
| 257 |
|
| 258 |
llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 259 |
env = EnvClient(SPACE_URL)
|
| 260 |
|
| 261 |
if not env.wait_until_ready():
|
| 262 |
print("[ERROR] Server not reachable. Exiting.", flush=True)
|
|
|
|
|
|
|
|
|
|
| 263 |
sys.exit(1)
|
| 264 |
|
| 265 |
task_scores: Dict[str, float] = {}
|
| 266 |
|
| 267 |
for task_id in TASKS:
|
| 268 |
-
print(f"\n
|
| 269 |
try:
|
| 270 |
task_scores[task_id] = run_task(task_id, env, llm)
|
| 271 |
except Exception as e:
|
| 272 |
print(f"[DEBUG] Task {task_id} crashed: {e}", flush=True)
|
|
|
|
|
|
|
| 273 |
task_scores[task_id] = 0.0
|
| 274 |
time.sleep(1)
|
| 275 |
|
| 276 |
env.close()
|
| 277 |
|
| 278 |
-
print(f"\n{'='*50}\nFINAL SCORES\n{'='*50}", flush=True)
|
| 279 |
-
for tid, s in task_scores.items():
|
| 280 |
-
status = "PASS" if s >= SUCCESS_SCORE_THRESHOLD else "FAIL"
|
| 281 |
-
print(f" {tid:25s}: {s:.4f} [{status}]", flush=True)
|
| 282 |
overall = sum(task_scores.values()) / len(task_scores)
|
| 283 |
-
print(f"\n
|
| 284 |
|
| 285 |
|
| 286 |
if __name__ == "__main__":
|
|
|
|
| 1 |
"""
|
| 2 |
inference.py — Baseline inference script for CodeReview-Env.
|
| 3 |
+
Uses required [START] / [STEP] / [END] plain-text log format.
|
|
|
|
| 4 |
|
| 5 |
Environment variables:
|
| 6 |
API_BASE_URL LLM API base URL
|
|
|
|
| 30 |
TASKS = ["easy_syntax", "medium_logic", "hard_security"]
|
| 31 |
|
| 32 |
|
| 33 |
+
# ── MANDATORY log format: plain text [START]/[STEP]/[END] ─────
|
| 34 |
|
| 35 |
def log_start(task: str, env: str, model: str) -> None:
|
| 36 |
+
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 37 |
|
| 38 |
|
| 39 |
def log_step(step: int, action: Any, reward: float, done: bool, error: Optional[str] = None) -> None:
|
| 40 |
+
action_str = str(action)[:100].replace("\n", " ")
|
| 41 |
+
error_str = error if error else "null"
|
| 42 |
+
print(f"[STEP] step={step} action={action_str} reward={reward} done={done} error={error_str}", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 46 |
+
rewards_str = str([round(r, 4) for r in rewards])
|
| 47 |
+
print(f"[END] success={success} steps={steps} score={score} rewards={rewards_str}", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
# ── HTTP client for OpenEnv server ───────────────────────────
|
| 51 |
|
| 52 |
class EnvClient:
|
|
|
|
|
|
|
| 53 |
def __init__(self, base_url: str) -> None:
|
| 54 |
self.base_url = base_url
|
| 55 |
self._http = httpx.Client(timeout=60.0)
|
|
|
|
| 59 |
try:
|
| 60 |
r = self._http.get(f"{self.base_url}/health")
|
| 61 |
if r.status_code == 200:
|
| 62 |
+
print(f"[DEBUG] Server ready", flush=True)
|
| 63 |
return True
|
| 64 |
except Exception as e:
|
| 65 |
print(f"[DEBUG] Waiting for server ({i+1}/{retries}): {e}", flush=True)
|
|
|
|
| 67 |
return False
|
| 68 |
|
| 69 |
def reset(self) -> Dict:
|
|
|
|
| 70 |
try:
|
| 71 |
r = self._http.post(f"{self.base_url}/reset")
|
| 72 |
r.raise_for_status()
|
|
|
|
| 76 |
return {"observation": {}, "reward": 0.0, "done": False}
|
| 77 |
|
| 78 |
def step(self, action: Dict) -> Dict:
|
|
|
|
| 79 |
try:
|
| 80 |
+
# OpenEnv create_app requires: {"action": {...}}
|
| 81 |
+
r = self._http.post(f"{self.base_url}/step", json={"action": action})
|
|
|
|
| 82 |
r.raise_for_status()
|
| 83 |
return r.json()
|
| 84 |
except Exception as e:
|
|
|
|
| 96 |
|
| 97 |
SYSTEM_PROMPT = """\
|
| 98 |
You are an expert software engineer specialising in code review, debugging, \
|
| 99 |
+
and security auditing. Analyse the code and return ONLY a JSON object \
|
| 100 |
+
(no prose, no markdown fences):
|
|
|
|
| 101 |
|
| 102 |
{
|
| 103 |
"identified_issues": [
|
|
|
|
| 109 |
}
|
| 110 |
],
|
| 111 |
"suggested_fix": "<complete corrected code as string, or null>",
|
| 112 |
+
"explanation": "<brief summary of all findings>",
|
| 113 |
"submit": true
|
| 114 |
}
|
| 115 |
"""
|
|
|
|
| 156 |
|
| 157 |
def parse_llm_output(raw: str) -> Dict:
|
| 158 |
raw = raw.strip()
|
|
|
|
| 159 |
if raw.startswith("```"):
|
| 160 |
parts = raw.split("```")
|
| 161 |
raw = parts[1] if len(parts) > 1 else raw
|
|
|
|
| 167 |
return {
|
| 168 |
"identified_issues": [],
|
| 169 |
"suggested_fix": None,
|
| 170 |
+
"explanation": raw[:200],
|
| 171 |
"submit": True,
|
| 172 |
}
|
| 173 |
|
|
|
|
| 186 |
try:
|
| 187 |
result = env.reset()
|
| 188 |
obs = result.get("observation", {})
|
| 189 |
+
max_steps = int(obs.get("max_steps", 5))
|
| 190 |
|
| 191 |
for step in range(1, max_steps + 1):
|
| 192 |
+
if result.get("done", False):
|
|
|
|
| 193 |
break
|
| 194 |
|
| 195 |
prompt = build_prompt(obs, step, prev_feedback)
|
|
|
|
| 207 |
|
| 208 |
log_step(
|
| 209 |
step=step,
|
| 210 |
+
action=action.get("explanation", "")[:100],
|
| 211 |
reward=reward,
|
| 212 |
done=done,
|
| 213 |
error=None,
|
|
|
|
| 222 |
|
| 223 |
except Exception as e:
|
| 224 |
print(f"[DEBUG] run_task error: {e}", flush=True)
|
| 225 |
+
log_step(step=steps_taken + 1, action="error", reward=0.0, done=True, error=str(e))
|
| 226 |
score = 0.0
|
| 227 |
success = False
|
| 228 |
|
|
|
|
| 235 |
# ── Main ──────────────────────────────────────────────────────
|
| 236 |
|
| 237 |
def main() -> None:
|
| 238 |
+
print(f"[DEBUG] Starting. SPACE_URL={SPACE_URL} MODEL={MODEL_NAME}", flush=True)
|
| 239 |
|
| 240 |
llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 241 |
env = EnvClient(SPACE_URL)
|
| 242 |
|
| 243 |
if not env.wait_until_ready():
|
| 244 |
print("[ERROR] Server not reachable. Exiting.", flush=True)
|
| 245 |
+
for task_id in TASKS:
|
| 246 |
+
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
|
| 247 |
+
log_end(success=False, steps=0, score=0.0, rewards=[])
|
| 248 |
sys.exit(1)
|
| 249 |
|
| 250 |
task_scores: Dict[str, float] = {}
|
| 251 |
|
| 252 |
for task_id in TASKS:
|
| 253 |
+
print(f"\n[DEBUG] ===== Running task: {task_id} =====", flush=True)
|
| 254 |
try:
|
| 255 |
task_scores[task_id] = run_task(task_id, env, llm)
|
| 256 |
except Exception as e:
|
| 257 |
print(f"[DEBUG] Task {task_id} crashed: {e}", flush=True)
|
| 258 |
+
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
|
| 259 |
+
log_end(success=False, steps=0, score=0.0, rewards=[])
|
| 260 |
task_scores[task_id] = 0.0
|
| 261 |
time.sleep(1)
|
| 262 |
|
| 263 |
env.close()
|
| 264 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
overall = sum(task_scores.values()) / len(task_scores)
|
| 266 |
+
print(f"\n[DEBUG] Overall average: {overall:.4f}", flush=True)
|
| 267 |
|
| 268 |
|
| 269 |
if __name__ == "__main__":
|