Mist-ic commited on
Commit
382d0fd
·
1 Parent(s): b971f92

Fix Phase 2 timeout: cap LLM retries, move server wait to main, add global time guard

Browse files
Files changed (1) hide show
  1. inference.py +54 -23
inference.py CHANGED
@@ -117,25 +117,28 @@ def _track_usage(completion: Any) -> None:
117
  # ---------------------------------------------------------------------------
118
 
119
 
 
 
 
120
  def _call_llm(messages: List[Dict[str, Any]], client: OpenAI) -> str:
121
- """Call the LLM with exponential backoff retry. Returns raw response text."""
122
- attempt = 0
123
- while True:
124
  try:
125
  completion = client.chat.completions.create(
126
  model=MODEL_NAME,
127
  messages=messages,
128
  temperature=0,
129
- max_tokens=1024,
 
130
  )
131
  _track_usage(completion)
132
  return completion.choices[0].message.content or ""
133
  except Exception as e:
134
- attempt += 1
135
- wait = min(10 * (2 ** (attempt - 1)), 60)
136
- print(f" [attempt {attempt}] {MODEL_NAME} error: {e}", flush=True)
137
- print(f" [retry] waiting {wait}s...", flush=True)
138
- time.sleep(wait)
139
 
140
 
141
  # ---------------------------------------------------------------------------
@@ -236,19 +239,20 @@ def parse_action(response_text: str) -> Dict[str, Any]:
236
  # ---------------------------------------------------------------------------
237
 
238
 
239
- def _wait_for_server(base: str, max_wait: int = 90) -> None:
240
  """Poll /health until server is ready or timeout."""
241
- import httpx, time
242
  deadline = time.time() + max_wait
243
  while time.time() < deadline:
244
  try:
245
  r = httpx.get(f"{base}/health", timeout=5.0)
246
  if r.status_code == 200:
 
247
  return
248
  except Exception:
249
  pass
250
- time.sleep(3)
251
- raise RuntimeError(f"Server at {base} not ready after {max_wait}s")
252
 
253
 
254
  def run_episode(
@@ -260,16 +264,21 @@ def run_episode(
260
 
261
  base = ENV_URL.rstrip("/")
262
 
263
- # Wait for server to be ready (handles startup race condition)
264
- _wait_for_server(base)
265
-
266
  # Reset environment
267
- reset_resp = httpx.post(
268
- f"{base}/reset",
269
- json={"seed": seed, "task_id": task_id},
270
- timeout=30.0,
271
- )
272
- resp_data = reset_resp.json()
 
 
 
 
 
 
 
 
273
  obs = resp_data.get("observation", resp_data)
274
 
275
  max_steps = obs.get("max_steps", 10)
@@ -297,6 +306,9 @@ def run_episode(
297
  for step_num in range(1, max_steps + 1):
298
  if done:
299
  break
 
 
 
300
 
301
  user_msg = build_observation_prompt(obs)
302
  conversation_history.append({"role": "user", "content": user_msg})
@@ -400,9 +412,23 @@ def run_episode(
400
  # ---------------------------------------------------------------------------
401
 
402
 
 
 
 
 
 
 
 
 
403
  def main() -> None:
 
 
 
404
  client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
405
 
 
 
 
406
  all_tasks = {"easy": 42, "medium": 123, "hard": 7}
407
  task_filter = os.getenv("TASKS", "").strip()
408
  selected = [t.strip() for t in task_filter.split(",")] if task_filter else list(all_tasks)
@@ -418,6 +444,9 @@ def main() -> None:
418
 
419
  results = []
420
  for task_id, seed in tasks:
 
 
 
421
  print(f"--- Task: {task_id} (seed={seed}) ---", flush=True)
422
  result = run_episode(client, task_id, seed)
423
  results.append(result)
@@ -453,8 +482,10 @@ def main() -> None:
453
  out_file.write_text(json.dumps(payload, indent=2))
454
  print(f"\n Results saved -> {out_file.name}", flush=True)
455
 
 
456
  total = _token_usage["prompt"] + _token_usage["completion"]
457
- print(f"\n Token usage:", flush=True)
 
458
  print(f" prompt: {_token_usage['prompt']:,}", flush=True)
459
  print(f" completion: {_token_usage['completion']:,}", flush=True)
460
  print(f" total: {total:,}", flush=True)
 
117
  # ---------------------------------------------------------------------------
118
 
119
 
120
+ MAX_LLM_RETRIES = 3
121
+
122
+
123
  def _call_llm(messages: List[Dict[str, Any]], client: OpenAI) -> str:
124
+ """Call the LLM with bounded retry. Returns raw response text."""
125
+ for attempt in range(1, MAX_LLM_RETRIES + 1):
 
126
  try:
127
  completion = client.chat.completions.create(
128
  model=MODEL_NAME,
129
  messages=messages,
130
  temperature=0,
131
+ max_tokens=512,
132
+ timeout=30.0,
133
  )
134
  _track_usage(completion)
135
  return completion.choices[0].message.content or ""
136
  except Exception as e:
137
+ print(f" [attempt {attempt}/{MAX_LLM_RETRIES}] {MODEL_NAME} error: {e}", flush=True)
138
+ if attempt < MAX_LLM_RETRIES:
139
+ wait = min(5 * attempt, 15)
140
+ time.sleep(wait)
141
+ return '{"action_type": "noop", "params": {}}'
142
 
143
 
144
  # ---------------------------------------------------------------------------
 
239
  # ---------------------------------------------------------------------------
240
 
241
 
242
+ def _wait_for_server(base: str, max_wait: int = 30) -> None:
243
  """Poll /health until server is ready or timeout."""
244
+ import httpx
245
  deadline = time.time() + max_wait
246
  while time.time() < deadline:
247
  try:
248
  r = httpx.get(f"{base}/health", timeout=5.0)
249
  if r.status_code == 200:
250
+ print(f" Server ready at {base}", flush=True)
251
  return
252
  except Exception:
253
  pass
254
+ time.sleep(2)
255
+ print(f" [warn] Server not confirmed ready after {max_wait}s, proceeding anyway", flush=True)
256
 
257
 
258
  def run_episode(
 
264
 
265
  base = ENV_URL.rstrip("/")
266
 
 
 
 
267
  # Reset environment
268
+ try:
269
+ reset_resp = httpx.post(
270
+ f"{base}/reset",
271
+ json={"seed": seed, "task_id": task_id},
272
+ timeout=30.0,
273
+ )
274
+ resp_data = reset_resp.json()
275
+ except Exception as e:
276
+ print(f" [reset error] {e}", flush=True)
277
+ log_start(task=task_id, env=ENV_NAME, model=MODEL_NAME)
278
+ log_end(task=task_id, success=False, steps=0, score=0.0, rewards=[])
279
+ return {"task_id": task_id, "seed": seed, "score": 0.0, "slo_recovery": 0.0,
280
+ "action_efficiency": 0.0, "time_efficiency": 0.0, "steps_taken": 0,
281
+ "termination_reason": "reset_error", "rewards": []}
282
  obs = resp_data.get("observation", resp_data)
283
 
284
  max_steps = obs.get("max_steps", 10)
 
306
  for step_num in range(1, max_steps + 1):
307
  if done:
308
  break
309
+ if _time_remaining() < 30:
310
+ print(f" [timeout guard] Stopping episode at step {step_num} — {_time_remaining():.0f}s left", flush=True)
311
+ break
312
 
313
  user_msg = build_observation_prompt(obs)
314
  conversation_history.append({"role": "user", "content": user_msg})
 
412
  # ---------------------------------------------------------------------------
413
 
414
 
415
+ GLOBAL_TIMEOUT = 20 * 60 # 20 minutes hard cap (validator limit is 30 min)
416
+ _start_time: float = 0.0
417
+
418
+
419
+ def _time_remaining() -> float:
420
+ return max(0, GLOBAL_TIMEOUT - (time.time() - _start_time))
421
+
422
+
423
  def main() -> None:
424
+ global _start_time
425
+ _start_time = time.time()
426
+
427
  client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
428
 
429
+ base = ENV_URL.rstrip("/")
430
+ _wait_for_server(base)
431
+
432
  all_tasks = {"easy": 42, "medium": 123, "hard": 7}
433
  task_filter = os.getenv("TASKS", "").strip()
434
  selected = [t.strip() for t in task_filter.split(",")] if task_filter else list(all_tasks)
 
444
 
445
  results = []
446
  for task_id, seed in tasks:
447
+ if _time_remaining() < 60:
448
+ print(f" [timeout guard] Skipping {task_id} — only {_time_remaining():.0f}s left", flush=True)
449
+ break
450
  print(f"--- Task: {task_id} (seed={seed}) ---", flush=True)
451
  result = run_episode(client, task_id, seed)
452
  results.append(result)
 
482
  out_file.write_text(json.dumps(payload, indent=2))
483
  print(f"\n Results saved -> {out_file.name}", flush=True)
484
 
485
+ elapsed = time.time() - _start_time
486
  total = _token_usage["prompt"] + _token_usage["completion"]
487
+ print(f"\n Wall time: {elapsed:.0f}s ({elapsed/60:.1f}min)", flush=True)
488
+ print(f" Token usage:", flush=True)
489
  print(f" prompt: {_token_usage['prompt']:,}", flush=True)
490
  print(f" completion: {_token_usage['completion']:,}", flush=True)
491
  print(f" total: {total:,}", flush=True)