Hassan Shaikh committed on
Commit
ad262b3
·
1 Parent(s): 43d8ac0

fix: harden inference against validator runtime failures

Browse files
Files changed (1) hide show
  1. inference.py +40 -9
inference.py CHANGED
@@ -19,6 +19,8 @@ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
19
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
20
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
21
  ENV_BASE_URL = os.getenv("ENV_BASE_URL")
 
 
22
  TASK_IDS = [t.strip() for t in os.getenv("TASK_IDS", "easy,medium,hard").split(",") if t.strip()]
23
  MAX_STEPS = int(os.getenv("MAX_STEPS", "12"))
24
  TEMPERATURE = 0.0
@@ -65,6 +67,13 @@ def _default_action() -> Dict[str, Any]:
65
  }
66
 
67
 
 
 
 
 
 
 
 
68
  def _parse_action(raw: str, available_files: List[str]) -> Dict[str, Any]:
69
  try:
70
  parsed = json.loads(raw)
@@ -145,13 +154,17 @@ def _query_model(client: OpenAI, obs: Any, step: int) -> Dict[str, Any]:
145
 
146
 
147
  async def _create_env() -> CodeSecurityAuditorEnv:
148
- if LOCAL_IMAGE_NAME:
149
- return await CodeSecurityAuditorEnv.from_docker_image(LOCAL_IMAGE_NAME)
150
  if ENV_BASE_URL:
151
  return CodeSecurityAuditorEnv(base_url=ENV_BASE_URL)
152
- raise RuntimeError(
153
- "Set LOCAL_IMAGE_NAME (docker mode) or ENV_BASE_URL (remote mode) to run inference."
154
- )
 
 
 
 
 
155
 
156
 
157
  async def run_task(env: CodeSecurityAuditorEnv, client: OpenAI, task_id: str) -> float:
@@ -191,6 +204,14 @@ async def run_task(env: CodeSecurityAuditorEnv, client: OpenAI, task_id: str) ->
191
  score = float(obs.reward or 0.0)
192
  score = min(max(score, 0.0), 1.0)
193
  success = score >= 0.6
 
 
 
 
 
 
 
 
194
  finally:
195
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
196
 
@@ -198,11 +219,21 @@ async def run_task(env: CodeSecurityAuditorEnv, client: OpenAI, task_id: str) ->
198
 
199
 
200
  async def main() -> None:
201
- if not API_KEY:
202
- raise RuntimeError("HF_TOKEN (or API_KEY) is required for inference.")
203
 
204
- client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
205
- env = await _create_env()
 
 
 
 
 
 
 
 
 
 
206
 
207
  try:
208
  scores: List[float] = []
 
19
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
20
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
21
  ENV_BASE_URL = os.getenv("ENV_BASE_URL")
22
+ DEFAULT_ENV_BASE_URL = os.getenv("DEFAULT_ENV_BASE_URL", "http://127.0.0.1:8000")
23
+ DEFAULT_LOCAL_IMAGE_NAME = os.getenv("DEFAULT_LOCAL_IMAGE_NAME", "code-security-auditor-env:latest")
24
  TASK_IDS = [t.strip() for t in os.getenv("TASK_IDS", "easy,medium,hard").split(",") if t.strip()]
25
  MAX_STEPS = int(os.getenv("MAX_STEPS", "12"))
26
  TEMPERATURE = 0.0
 
67
  }
68
 
69
 
70
+ def _safe_error(exc: Exception) -> str:
71
+ msg = str(exc).strip()
72
+ if not msg:
73
+ msg = exc.__class__.__name__
74
+ return msg.replace("\n", " ")[:240]
75
+
76
+
77
  def _parse_action(raw: str, available_files: List[str]) -> Dict[str, Any]:
78
  try:
79
  parsed = json.loads(raw)
 
154
 
155
 
156
  async def _create_env() -> CodeSecurityAuditorEnv:
157
+ # Prefer explicit configuration, then fall back to common local defaults.
 
158
  if ENV_BASE_URL:
159
  return CodeSecurityAuditorEnv(base_url=ENV_BASE_URL)
160
+
161
+ if LOCAL_IMAGE_NAME:
162
+ return await CodeSecurityAuditorEnv.from_docker_image(LOCAL_IMAGE_NAME)
163
+
164
+ try:
165
+ return CodeSecurityAuditorEnv(base_url=DEFAULT_ENV_BASE_URL)
166
+ except Exception:
167
+ return await CodeSecurityAuditorEnv.from_docker_image(DEFAULT_LOCAL_IMAGE_NAME)
168
 
169
 
170
  async def run_task(env: CodeSecurityAuditorEnv, client: OpenAI, task_id: str) -> float:
 
204
  score = float(obs.reward or 0.0)
205
  score = min(max(score, 0.0), 1.0)
206
  success = score >= 0.6
207
+ except Exception as exc:
208
+ # Keep evaluator contract: do not crash inference.py on transient/runtime errors.
209
+ log_step(step=max(1, steps_taken), action="{}", reward=0.0, done=True, error=_safe_error(exc))
210
+ if not rewards:
211
+ rewards.append(0.0)
212
+ steps_taken = max(1, steps_taken)
213
+ score = 0.0
214
+ success = False
215
  finally:
216
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
217
 
 
219
 
220
 
221
  async def main() -> None:
222
+ # Keep script resilient in validators even if a key is temporarily unavailable.
223
+ api_key = API_KEY or "missing"
224
 
225
+ client = OpenAI(base_url=API_BASE_URL, api_key=api_key)
226
+
227
+ try:
228
+ env = await _create_env()
229
+ except Exception as exc:
230
+ # Emit structured logs for each task and exit cleanly.
231
+ err = _safe_error(exc)
232
+ for task_id in TASK_IDS:
233
+ log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
234
+ log_step(step=1, action="{}", reward=0.0, done=True, error=err)
235
+ log_end(success=False, steps=1, score=0.0, rewards=[0.0])
236
+ return
237
 
238
  try:
239
  scores: List[float] = []