vedkdev committed on
Commit
bbd2278
·
verified ·
1 Parent(s): dbfc631

Deploy FlakyGym UI + inference updates (minimal upload)

Browse files
Files changed (3) hide show
  1. Dockerfile +0 -1
  2. inference.py +183 -12
  3. inference_debug.py +183 -12
Dockerfile CHANGED
@@ -14,5 +14,4 @@ COPY . .
14
 
15
  EXPOSE 8000
16
 
17
- ENV ENABLE_WEB_INTERFACE=true
18
  CMD ["python", "-m", "server.app"]
 
14
 
15
  EXPOSE 8000
16
 
 
17
  CMD ["python", "-m", "server.app"]
inference.py CHANGED
@@ -78,6 +78,7 @@ MODEL_NAME = os.environ.get("MODEL_NAME", DEFAULT_MODEL)
78
  EPISODES_PER_TASK = 2
79
  MAX_STEPS = 20
80
  MEMORY_MAX_CHARS = 900
 
81
 
82
  client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
83
 
@@ -123,6 +124,144 @@ def _short_error(text: str, max_chars: int = 220) -> str:
123
  return f"{one_line[:max_chars]}...[truncated {hidden} chars]"
124
 
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def _compliance_log_start(task: str, benchmark: str, model: str) -> None:
127
  print(f"[START] task={task} env={benchmark} model={model}", flush=True)
128
 
@@ -215,22 +354,52 @@ def llm_action(
215
  "attempted": False,
216
  "raw_output": "",
217
  "error": "",
 
 
218
  }
219
  if not API_KEY:
 
220
  return None, meta
221
 
222
- meta["attempted"] = True
223
- response = client.chat.completions.create(
224
- model=MODEL_NAME,
225
- messages=messages,
226
- max_tokens=400,
227
- temperature=0.0,
228
- )
229
- raw = (response.choices[0].message.content or "").strip()
230
- meta["raw_output"] = raw
231
- cleaned = raw.replace("```json", "").replace("```", "").strip()
232
- payload = json.loads(cleaned)
233
- return FlakySleuthAction.model_validate(payload), meta
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
 
236
  def _clip_text(text: str, max_chars: int) -> str:
@@ -477,6 +646,8 @@ def run_episode(
477
  heuristic_steps += 1
478
  if not API_KEY:
479
  reason_key = "no_api_key"
 
 
480
  elif llm_meta.get("error"):
481
  reason_key = "llm_error"
482
  elif llm_meta.get("attempted"):
 
78
  EPISODES_PER_TASK = 2
79
  MAX_STEPS = 20
80
  MEMORY_MAX_CHARS = 900
81
+ LLM_MAX_RETRIES = 2
82
 
83
  client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
84
 
 
124
  return f"{one_line[:max_chars]}...[truncated {hidden} chars]"
125
 
126
 
127
class _ActionParseError(Exception):
    """Signal that an LLM reply could not be converted into an action.

    ``reason`` is a short category key (for example ``"llm_schema_error"``)
    and ``detail`` is the explanatory text. The exception message is
    ``"<reason>: <detail>"``.
    """

    def __init__(self, reason: str, detail: str) -> None:
        # Keep both parts accessible separately so callers can branch on the
        # category without parsing the combined message.
        self.reason = reason
        self.detail = detail
        message = f"{reason}: {detail}"
        super().__init__(message)
132
+
133
+
134
def _strip_code_fences(text: str) -> str:
    """Return *text* with a surrounding Markdown code fence removed.

    Handles both multi-line fences (the opening line, including any info
    string such as ``json``, is dropped) and inline fences where the payload
    shares a line with the backticks. A closing fence is removed whether it
    sits on its own line or is glued to the end of the last payload line.
    A leftover leading ``json`` language line is stripped as well.

    Args:
        text: Raw model output, possibly wrapped in triple backticks.

    Returns:
        The unwrapped text with surrounding whitespace trimmed.
    """
    stripped = text.strip()
    if stripped.startswith("```"):
        lines = stripped.splitlines()
        if len(lines) == 1:
            # Inline fence: the payload lives on the same line as the opening
            # backticks, so dropping the whole line would lose it.
            inline = lines[0][3:]
            if inline[:4].lower() == "json" and not inline[4:5].isalnum():
                inline = inline[4:]
            lines = [inline]
        else:
            # Multi-line fence: the first line only carries the info string.
            lines = lines[1:]
        if lines:
            last = lines[-1].rstrip()
            if last.endswith("```"):
                # Covers a fence on its own line and one glued to the payload.
                lines[-1] = last[:-3]
        stripped = "\n".join(lines).strip()
    if stripped.lower().startswith("json\n"):
        stripped = stripped[5:].strip()
    return stripped
146
+
147
+
148
+ def _extract_first_json_object(text: str) -> str | None:
149
+ start = -1
150
+ depth = 0
151
+ in_string = False
152
+ escaped = False
153
+ for idx, ch in enumerate(text):
154
+ if in_string:
155
+ if escaped:
156
+ escaped = False
157
+ continue
158
+ if ch == "\\":
159
+ escaped = True
160
+ continue
161
+ if ch == '"':
162
+ in_string = False
163
+ continue
164
+
165
+ if ch == '"':
166
+ in_string = True
167
+ continue
168
+ if ch == "{":
169
+ if depth == 0:
170
+ start = idx
171
+ depth += 1
172
+ continue
173
+ if ch == "}":
174
+ if depth == 0:
175
+ continue
176
+ depth -= 1
177
+ if depth == 0 and start >= 0:
178
+ return text[start : idx + 1]
179
+ return None
180
+
181
+
182
def _parse_action_payload(raw: str) -> tuple[FlakySleuthAction, str]:
    """Turn a raw LLM reply into a validated ``FlakySleuthAction``.

    Several textual candidates are tried in order: the trimmed reply, the
    reply with code fences removed, and the first brace-balanced object found
    in each of those. The first candidate that decodes to a JSON object and
    passes ``FlakySleuthAction.model_validate`` wins.

    Args:
        raw: The model's reply text (``None``/empty is treated as empty).

    Returns:
        A ``(action, candidate_text)`` tuple, where ``candidate_text`` is the
        exact string that was successfully parsed.

    Raises:
        _ActionParseError: With reason ``llm_empty_output`` for an empty
            reply, ``llm_schema_error`` when at least one candidate decoded
            but none validated, or ``llm_json_parse_error`` otherwise.
    """
    raw_text = (raw or "").strip()
    if not raw_text:
        raise _ActionParseError("llm_empty_output", "empty response body")

    fence_free = _strip_code_fences(raw_text)
    proposals = (
        raw_text,
        fence_free,
        _extract_first_json_object(fence_free),
        _extract_first_json_object(raw_text),
    )

    # A dict keeps insertion order, giving an ordered de-duplication.
    ordered: dict[str, None] = {}
    for proposal in proposals:
        if proposal is None:
            continue
        trimmed = proposal.strip()
        if trimmed:
            ordered.setdefault(trimmed, None)

    # Only the most recent failure of each kind is ever reported.
    last_decode_failure: str | None = None
    last_schema_failure: str | None = None

    for candidate in ordered:
        try:
            payload = json.loads(candidate)
        except json.JSONDecodeError as exc:
            last_decode_failure = str(exc)
            continue
        if isinstance(payload, dict):
            try:
                action = FlakySleuthAction.model_validate(payload)
            except Exception as exc:
                last_schema_failure = str(exc)
            else:
                return action, candidate
        else:
            last_schema_failure = (
                f"top-level JSON must be an object, got {type(payload).__name__}"
            )

    # Schema problems take precedence: they mean the JSON itself was readable.
    if last_schema_failure is not None:
        raise _ActionParseError(
            "llm_schema_error", _short_error(last_schema_failure, max_chars=300)
        )
    if last_decode_failure is not None:
        raise _ActionParseError(
            "llm_json_parse_error", _short_error(last_decode_failure, max_chars=300)
        )
    raise _ActionParseError("llm_json_parse_error", "unable to extract JSON object")
229
+
230
+
231
def _json_repair_prompt(error_text: str, raw_output: str) -> str:
    """Build the follow-up user message asking the model to resend valid JSON.

    Args:
        error_text: Description of what went wrong; clipped via
            ``_short_error`` with ``max_chars=260``.
        raw_output: The rejected reply; clipped with ``max_chars=300``, with
            ``"(empty)"`` substituted when it is falsy.

    Returns:
        A six-line prompt string (no trailing newline).
    """
    previous = _short_error(raw_output or "(empty)", max_chars=300)
    problem = _short_error(error_text, max_chars=260)
    prompt_lines = [
        "Your previous response was invalid.",
        f"Parser error: {problem}",
        f"Previous output (truncated): {previous}",
        "Respond again with ONLY one valid JSON object and no extra text.",
        'Required schema: {"action_type": "<one valid action>", "argument": "<string>", "metadata": {}}',
        "Do NOT wrap in markdown fences. Do NOT add commentary.",
    ]
    return "\n".join(prompt_lines)
242
+
243
+
244
def _chat_completion_request(messages: list[dict[str, str]]) -> Any:
    """Request a chat completion, preferring JSON mode with a plain fallback.

    The first call passes ``response_format={"type": "json_object"}``. If it
    raises any exception, the same request is retried once without that
    option.
    NOTE(review): presumably the fallback exists for backends that reject
    ``response_format`` — confirm, since every error type triggers it.

    Args:
        messages: Chat messages forwarded unchanged to the client.

    Returns:
        Whatever ``client.chat.completions.create`` returns.

    Raises:
        RuntimeError: When both calls fail; the message contains both error
            texts and the exception is chained from the second failure.
    """
    shared_kwargs = dict(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=400,
        temperature=0.0,
    )
    try:
        return client.chat.completions.create(
            response_format={"type": "json_object"},
            **shared_kwargs,
        )
    except Exception as json_mode_exc:
        # Retry inside the handler so the second failure keeps the first one
        # as its implicit context.
        try:
            return client.chat.completions.create(**shared_kwargs)
        except Exception as plain_mode_exc:
            combined = (
                f"json_mode_error={json_mode_exc}; plain_mode_error={plain_mode_exc}"
            )
            raise RuntimeError(combined) from plain_mode_exc
263
+
264
+
265
  def _compliance_log_start(task: str, benchmark: str, model: str) -> None:
266
  print(f"[START] task={task} env={benchmark} model={model}", flush=True)
267
 
 
354
  "attempted": False,
355
  "raw_output": "",
356
  "error": "",
357
+ "reason": "",
358
+ "attempt_count": 0,
359
  }
360
  if not API_KEY:
361
+ meta["reason"] = "no_api_key"
362
  return None, meta
363
 
364
+ work_messages = list(messages)
365
+ last_error = ""
366
+ for attempt in range(LLM_MAX_RETRIES + 1):
367
+ meta["attempted"] = True
368
+ meta["attempt_count"] = attempt + 1
369
+ try:
370
+ response = _chat_completion_request(work_messages)
371
+ except Exception as exc:
372
+ last_error = f"request_failed attempt={attempt + 1}: {exc}"
373
+ meta["error"] = _short_error(last_error, max_chars=500)
374
+ meta["reason"] = "llm_http_error"
375
+ if attempt < LLM_MAX_RETRIES:
376
+ work_messages = work_messages + [
377
+ {"role": "user", "content": _json_repair_prompt(last_error, "")}
378
+ ]
379
+ continue
380
+ return None, meta
381
+
382
+ raw = (response.choices[0].message.content or "").strip()
383
+ meta["raw_output"] = raw
384
+ try:
385
+ action, _ = _parse_action_payload(raw)
386
+ meta["error"] = ""
387
+ meta["reason"] = "ok"
388
+ return action, meta
389
+ except _ActionParseError as exc:
390
+ last_error = f"{exc.reason}: {exc.detail}"
391
+ meta["error"] = _short_error(last_error, max_chars=500)
392
+ meta["reason"] = exc.reason
393
+ if attempt < LLM_MAX_RETRIES:
394
+ work_messages = work_messages + [
395
+ {"role": "user", "content": _json_repair_prompt(last_error, raw)}
396
+ ]
397
+ continue
398
+ return None, meta
399
+
400
+ meta["error"] = _short_error(last_error or "unknown llm failure", max_chars=500)
401
+ meta["reason"] = meta["reason"] or "llm_error"
402
+ return None, meta
403
 
404
 
405
  def _clip_text(text: str, max_chars: int) -> str:
 
646
  heuristic_steps += 1
647
  if not API_KEY:
648
  reason_key = "no_api_key"
649
+ elif llm_meta.get("reason") and llm_meta.get("reason") != "ok":
650
+ reason_key = str(llm_meta.get("reason"))
651
  elif llm_meta.get("error"):
652
  reason_key = "llm_error"
653
  elif llm_meta.get("attempted"):
inference_debug.py CHANGED
@@ -76,6 +76,7 @@ MODEL_NAME = os.environ.get("MODEL_NAME", DEFAULT_MODEL)
76
  EPISODES_PER_TASK = 2
77
  MAX_STEPS = 20
78
  MEMORY_MAX_CHARS = 900
 
79
 
80
  client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
81
 
@@ -121,6 +122,144 @@ def _short_error(text: str, max_chars: int = 220) -> str:
121
  return f"{one_line[:max_chars]}...[truncated {hidden} chars]"
122
 
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  def _compliance_log_start(task: str, benchmark: str, model: str) -> None:
125
  print(f"[START] task={task} env={benchmark} model={model}", flush=True)
126
 
@@ -208,22 +347,52 @@ def llm_action(
208
  "attempted": False,
209
  "raw_output": "",
210
  "error": "",
 
 
211
  }
212
  if not API_KEY:
 
213
  return None, meta
214
 
215
- meta["attempted"] = True
216
- response = client.chat.completions.create(
217
- model=MODEL_NAME,
218
- messages=messages,
219
- max_tokens=400,
220
- temperature=0.0,
221
- )
222
- raw = (response.choices[0].message.content or "").strip()
223
- meta["raw_output"] = raw
224
- cleaned = raw.replace("```json", "").replace("```", "").strip()
225
- payload = json.loads(cleaned)
226
- return FlakySleuthAction.model_validate(payload), meta
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
 
229
  def _clip_text(text: str, max_chars: int) -> str:
@@ -470,6 +639,8 @@ def run_episode(
470
  heuristic_steps += 1
471
  if not API_KEY:
472
  reason_key = "no_api_key"
 
 
473
  elif llm_meta.get("error"):
474
  reason_key = "llm_error"
475
  elif llm_meta.get("attempted"):
 
76
  EPISODES_PER_TASK = 2
77
  MAX_STEPS = 20
78
  MEMORY_MAX_CHARS = 900
79
+ LLM_MAX_RETRIES = 2
80
 
81
  client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
82
 
 
122
  return f"{one_line[:max_chars]}...[truncated {hidden} chars]"
123
 
124
 
125
class _ActionParseError(Exception):
    """Signal that an LLM reply could not be converted into an action.

    ``reason`` is a short category key (for example ``"llm_schema_error"``)
    and ``detail`` is the explanatory text. The exception message is
    ``"<reason>: <detail>"``.
    """

    def __init__(self, reason: str, detail: str) -> None:
        # Keep both parts accessible separately so callers can branch on the
        # category without parsing the combined message.
        self.reason = reason
        self.detail = detail
        message = f"{reason}: {detail}"
        super().__init__(message)
130
+
131
+
132
def _strip_code_fences(text: str) -> str:
    """Return *text* with a surrounding Markdown code fence removed.

    Handles both multi-line fences (the opening line, including any info
    string such as ``json``, is dropped) and inline fences where the payload
    shares a line with the backticks. A closing fence is removed whether it
    sits on its own line or is glued to the end of the last payload line.
    A leftover leading ``json`` language line is stripped as well.

    Args:
        text: Raw model output, possibly wrapped in triple backticks.

    Returns:
        The unwrapped text with surrounding whitespace trimmed.
    """
    stripped = text.strip()
    if stripped.startswith("```"):
        lines = stripped.splitlines()
        if len(lines) == 1:
            # Inline fence: the payload lives on the same line as the opening
            # backticks, so dropping the whole line would lose it.
            inline = lines[0][3:]
            if inline[:4].lower() == "json" and not inline[4:5].isalnum():
                inline = inline[4:]
            lines = [inline]
        else:
            # Multi-line fence: the first line only carries the info string.
            lines = lines[1:]
        if lines:
            last = lines[-1].rstrip()
            if last.endswith("```"):
                # Covers a fence on its own line and one glued to the payload.
                lines[-1] = last[:-3]
        stripped = "\n".join(lines).strip()
    if stripped.lower().startswith("json\n"):
        stripped = stripped[5:].strip()
    return stripped
144
+
145
+
146
+ def _extract_first_json_object(text: str) -> str | None:
147
+ start = -1
148
+ depth = 0
149
+ in_string = False
150
+ escaped = False
151
+ for idx, ch in enumerate(text):
152
+ if in_string:
153
+ if escaped:
154
+ escaped = False
155
+ continue
156
+ if ch == "\\":
157
+ escaped = True
158
+ continue
159
+ if ch == '"':
160
+ in_string = False
161
+ continue
162
+
163
+ if ch == '"':
164
+ in_string = True
165
+ continue
166
+ if ch == "{":
167
+ if depth == 0:
168
+ start = idx
169
+ depth += 1
170
+ continue
171
+ if ch == "}":
172
+ if depth == 0:
173
+ continue
174
+ depth -= 1
175
+ if depth == 0 and start >= 0:
176
+ return text[start : idx + 1]
177
+ return None
178
+
179
+
180
def _parse_action_payload(raw: str) -> tuple[FlakySleuthAction, str]:
    """Turn a raw LLM reply into a validated ``FlakySleuthAction``.

    Several textual candidates are tried in order: the trimmed reply, the
    reply with code fences removed, and the first brace-balanced object found
    in each of those. The first candidate that decodes to a JSON object and
    passes ``FlakySleuthAction.model_validate`` wins.

    Args:
        raw: The model's reply text (``None``/empty is treated as empty).

    Returns:
        A ``(action, candidate_text)`` tuple, where ``candidate_text`` is the
        exact string that was successfully parsed.

    Raises:
        _ActionParseError: With reason ``llm_empty_output`` for an empty
            reply, ``llm_schema_error`` when at least one candidate decoded
            but none validated, or ``llm_json_parse_error`` otherwise.
    """
    raw_text = (raw or "").strip()
    if not raw_text:
        raise _ActionParseError("llm_empty_output", "empty response body")

    fence_free = _strip_code_fences(raw_text)
    proposals = (
        raw_text,
        fence_free,
        _extract_first_json_object(fence_free),
        _extract_first_json_object(raw_text),
    )

    # A dict keeps insertion order, giving an ordered de-duplication.
    ordered: dict[str, None] = {}
    for proposal in proposals:
        if proposal is None:
            continue
        trimmed = proposal.strip()
        if trimmed:
            ordered.setdefault(trimmed, None)

    # Only the most recent failure of each kind is ever reported.
    last_decode_failure: str | None = None
    last_schema_failure: str | None = None

    for candidate in ordered:
        try:
            payload = json.loads(candidate)
        except json.JSONDecodeError as exc:
            last_decode_failure = str(exc)
            continue
        if isinstance(payload, dict):
            try:
                action = FlakySleuthAction.model_validate(payload)
            except Exception as exc:
                last_schema_failure = str(exc)
            else:
                return action, candidate
        else:
            last_schema_failure = (
                f"top-level JSON must be an object, got {type(payload).__name__}"
            )

    # Schema problems take precedence: they mean the JSON itself was readable.
    if last_schema_failure is not None:
        raise _ActionParseError(
            "llm_schema_error", _short_error(last_schema_failure, max_chars=300)
        )
    if last_decode_failure is not None:
        raise _ActionParseError(
            "llm_json_parse_error", _short_error(last_decode_failure, max_chars=300)
        )
    raise _ActionParseError("llm_json_parse_error", "unable to extract JSON object")
227
+
228
+
229
def _json_repair_prompt(error_text: str, raw_output: str) -> str:
    """Build the follow-up user message asking the model to resend valid JSON.

    Args:
        error_text: Description of what went wrong; clipped via
            ``_short_error`` with ``max_chars=260``.
        raw_output: The rejected reply; clipped with ``max_chars=300``, with
            ``"(empty)"`` substituted when it is falsy.

    Returns:
        A six-line prompt string (no trailing newline).
    """
    previous = _short_error(raw_output or "(empty)", max_chars=300)
    problem = _short_error(error_text, max_chars=260)
    prompt_lines = [
        "Your previous response was invalid.",
        f"Parser error: {problem}",
        f"Previous output (truncated): {previous}",
        "Respond again with ONLY one valid JSON object and no extra text.",
        'Required schema: {"action_type": "<one valid action>", "argument": "<string>", "metadata": {}}',
        "Do NOT wrap in markdown fences. Do NOT add commentary.",
    ]
    return "\n".join(prompt_lines)
240
+
241
+
242
def _chat_completion_request(messages: list[dict[str, str]]) -> Any:
    """Request a chat completion, preferring JSON mode with a plain fallback.

    The first call passes ``response_format={"type": "json_object"}``. If it
    raises any exception, the same request is retried once without that
    option.
    NOTE(review): presumably the fallback exists for backends that reject
    ``response_format`` — confirm, since every error type triggers it.

    Args:
        messages: Chat messages forwarded unchanged to the client.

    Returns:
        Whatever ``client.chat.completions.create`` returns.

    Raises:
        RuntimeError: When both calls fail; the message contains both error
            texts and the exception is chained from the second failure.
    """
    shared_kwargs = dict(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=400,
        temperature=0.0,
    )
    try:
        return client.chat.completions.create(
            response_format={"type": "json_object"},
            **shared_kwargs,
        )
    except Exception as json_mode_exc:
        # Retry inside the handler so the second failure keeps the first one
        # as its implicit context.
        try:
            return client.chat.completions.create(**shared_kwargs)
        except Exception as plain_mode_exc:
            combined = (
                f"json_mode_error={json_mode_exc}; plain_mode_error={plain_mode_exc}"
            )
            raise RuntimeError(combined) from plain_mode_exc
261
+
262
+
263
  def _compliance_log_start(task: str, benchmark: str, model: str) -> None:
264
  print(f"[START] task={task} env={benchmark} model={model}", flush=True)
265
 
 
347
  "attempted": False,
348
  "raw_output": "",
349
  "error": "",
350
+ "reason": "",
351
+ "attempt_count": 0,
352
  }
353
  if not API_KEY:
354
+ meta["reason"] = "no_api_key"
355
  return None, meta
356
 
357
+ work_messages = list(messages)
358
+ last_error = ""
359
+ for attempt in range(LLM_MAX_RETRIES + 1):
360
+ meta["attempted"] = True
361
+ meta["attempt_count"] = attempt + 1
362
+ try:
363
+ response = _chat_completion_request(work_messages)
364
+ except Exception as exc:
365
+ last_error = f"request_failed attempt={attempt + 1}: {exc}"
366
+ meta["error"] = _short_error(last_error, max_chars=500)
367
+ meta["reason"] = "llm_http_error"
368
+ if attempt < LLM_MAX_RETRIES:
369
+ work_messages = work_messages + [
370
+ {"role": "user", "content": _json_repair_prompt(last_error, "")}
371
+ ]
372
+ continue
373
+ return None, meta
374
+
375
+ raw = (response.choices[0].message.content or "").strip()
376
+ meta["raw_output"] = raw
377
+ try:
378
+ action, _ = _parse_action_payload(raw)
379
+ meta["error"] = ""
380
+ meta["reason"] = "ok"
381
+ return action, meta
382
+ except _ActionParseError as exc:
383
+ last_error = f"{exc.reason}: {exc.detail}"
384
+ meta["error"] = _short_error(last_error, max_chars=500)
385
+ meta["reason"] = exc.reason
386
+ if attempt < LLM_MAX_RETRIES:
387
+ work_messages = work_messages + [
388
+ {"role": "user", "content": _json_repair_prompt(last_error, raw)}
389
+ ]
390
+ continue
391
+ return None, meta
392
+
393
+ meta["error"] = _short_error(last_error or "unknown llm failure", max_chars=500)
394
+ meta["reason"] = meta["reason"] or "llm_error"
395
+ return None, meta
396
 
397
 
398
  def _clip_text(text: str, max_chars: int) -> str:
 
639
  heuristic_steps += 1
640
  if not API_KEY:
641
  reason_key = "no_api_key"
642
+ elif llm_meta.get("reason") and llm_meta.get("reason") != "ok":
643
+ reason_key = str(llm_meta.get("reason"))
644
  elif llm_meta.get("error"):
645
  reason_key = "llm_error"
646
  elif llm_meta.get("attempted"):