Fix Phase 2 timeout: cap LLM retries, move server wait to main, add global time guard
Browse files- inference.py +54 -23
inference.py
CHANGED
|
@@ -117,25 +117,28 @@ def _track_usage(completion: Any) -> None:
|
|
| 117 |
# ---------------------------------------------------------------------------
|
| 118 |
|
| 119 |
|
|
|
|
|
|
|
|
|
|
| 120 |
def _call_llm(messages: List[Dict[str, Any]], client: OpenAI) -> str:
|
| 121 |
-
"""Call the LLM with
|
| 122 |
-
attempt
|
| 123 |
-
while True:
|
| 124 |
try:
|
| 125 |
completion = client.chat.completions.create(
|
| 126 |
model=MODEL_NAME,
|
| 127 |
messages=messages,
|
| 128 |
temperature=0,
|
| 129 |
-
max_tokens=
|
|
|
|
| 130 |
)
|
| 131 |
_track_usage(completion)
|
| 132 |
return completion.choices[0].message.content or ""
|
| 133 |
except Exception as e:
|
| 134 |
-
attempt
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
|
| 140 |
|
| 141 |
# ---------------------------------------------------------------------------
|
|
@@ -236,19 +239,20 @@ def parse_action(response_text: str) -> Dict[str, Any]:
|
|
| 236 |
# ---------------------------------------------------------------------------
|
| 237 |
|
| 238 |
|
| 239 |
-
def _wait_for_server(base: str, max_wait: int =
|
| 240 |
"""Poll /health until server is ready or timeout."""
|
| 241 |
-
import httpx
|
| 242 |
deadline = time.time() + max_wait
|
| 243 |
while time.time() < deadline:
|
| 244 |
try:
|
| 245 |
r = httpx.get(f"{base}/health", timeout=5.0)
|
| 246 |
if r.status_code == 200:
|
|
|
|
| 247 |
return
|
| 248 |
except Exception:
|
| 249 |
pass
|
| 250 |
-
time.sleep(
|
| 251 |
-
|
| 252 |
|
| 253 |
|
| 254 |
def run_episode(
|
|
@@ -260,16 +264,21 @@ def run_episode(
|
|
| 260 |
|
| 261 |
base = ENV_URL.rstrip("/")
|
| 262 |
|
| 263 |
-
# Wait for server to be ready (handles startup race condition)
|
| 264 |
-
_wait_for_server(base)
|
| 265 |
-
|
| 266 |
# Reset environment
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
obs = resp_data.get("observation", resp_data)
|
| 274 |
|
| 275 |
max_steps = obs.get("max_steps", 10)
|
|
@@ -297,6 +306,9 @@ def run_episode(
|
|
| 297 |
for step_num in range(1, max_steps + 1):
|
| 298 |
if done:
|
| 299 |
break
|
|
|
|
|
|
|
|
|
|
| 300 |
|
| 301 |
user_msg = build_observation_prompt(obs)
|
| 302 |
conversation_history.append({"role": "user", "content": user_msg})
|
|
@@ -400,9 +412,23 @@ def run_episode(
|
|
| 400 |
# ---------------------------------------------------------------------------
|
| 401 |
|
| 402 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
def main() -> None:
|
|
|
|
|
|
|
|
|
|
| 404 |
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 405 |
|
|
|
|
|
|
|
|
|
|
| 406 |
all_tasks = {"easy": 42, "medium": 123, "hard": 7}
|
| 407 |
task_filter = os.getenv("TASKS", "").strip()
|
| 408 |
selected = [t.strip() for t in task_filter.split(",")] if task_filter else list(all_tasks)
|
|
@@ -418,6 +444,9 @@ def main() -> None:
|
|
| 418 |
|
| 419 |
results = []
|
| 420 |
for task_id, seed in tasks:
|
|
|
|
|
|
|
|
|
|
| 421 |
print(f"--- Task: {task_id} (seed={seed}) ---", flush=True)
|
| 422 |
result = run_episode(client, task_id, seed)
|
| 423 |
results.append(result)
|
|
@@ -453,8 +482,10 @@ def main() -> None:
|
|
| 453 |
out_file.write_text(json.dumps(payload, indent=2))
|
| 454 |
print(f"\n Results saved -> {out_file.name}", flush=True)
|
| 455 |
|
|
|
|
| 456 |
total = _token_usage["prompt"] + _token_usage["completion"]
|
| 457 |
-
print(f"\n
|
|
|
|
| 458 |
print(f" prompt: {_token_usage['prompt']:,}", flush=True)
|
| 459 |
print(f" completion: {_token_usage['completion']:,}", flush=True)
|
| 460 |
print(f" total: {total:,}", flush=True)
|
|
|
|
| 117 |
# ---------------------------------------------------------------------------
|
| 118 |
|
| 119 |
|
| 120 |
+
MAX_LLM_RETRIES = 3
|
| 121 |
+
|
| 122 |
+
|
| 123 |
def _call_llm(messages: List[Dict[str, Any]], client: OpenAI) -> str:
|
| 124 |
+
"""Call the LLM with bounded retry. Returns raw response text."""
|
| 125 |
+
for attempt in range(1, MAX_LLM_RETRIES + 1):
|
|
|
|
| 126 |
try:
|
| 127 |
completion = client.chat.completions.create(
|
| 128 |
model=MODEL_NAME,
|
| 129 |
messages=messages,
|
| 130 |
temperature=0,
|
| 131 |
+
max_tokens=512,
|
| 132 |
+
timeout=30.0,
|
| 133 |
)
|
| 134 |
_track_usage(completion)
|
| 135 |
return completion.choices[0].message.content or ""
|
| 136 |
except Exception as e:
|
| 137 |
+
print(f" [attempt {attempt}/{MAX_LLM_RETRIES}] {MODEL_NAME} error: {e}", flush=True)
|
| 138 |
+
if attempt < MAX_LLM_RETRIES:
|
| 139 |
+
wait = min(5 * attempt, 15)
|
| 140 |
+
time.sleep(wait)
|
| 141 |
+
return '{"action_type": "noop", "params": {}}'
|
| 142 |
|
| 143 |
|
| 144 |
# ---------------------------------------------------------------------------
|
|
|
|
| 239 |
# ---------------------------------------------------------------------------
|
| 240 |
|
| 241 |
|
| 242 |
+
def _wait_for_server(base: str, max_wait: int = 30) -> None:
|
| 243 |
"""Poll /health until server is ready or timeout."""
|
| 244 |
+
import httpx
|
| 245 |
deadline = time.time() + max_wait
|
| 246 |
while time.time() < deadline:
|
| 247 |
try:
|
| 248 |
r = httpx.get(f"{base}/health", timeout=5.0)
|
| 249 |
if r.status_code == 200:
|
| 250 |
+
print(f" Server ready at {base}", flush=True)
|
| 251 |
return
|
| 252 |
except Exception:
|
| 253 |
pass
|
| 254 |
+
time.sleep(2)
|
| 255 |
+
print(f" [warn] Server not confirmed ready after {max_wait}s, proceeding anyway", flush=True)
|
| 256 |
|
| 257 |
|
| 258 |
def run_episode(
|
|
|
|
| 264 |
|
| 265 |
base = ENV_URL.rstrip("/")
|
| 266 |
|
|
|
|
|
|
|
|
|
|
| 267 |
# Reset environment
|
| 268 |
+
try:
|
| 269 |
+
reset_resp = httpx.post(
|
| 270 |
+
f"{base}/reset",
|
| 271 |
+
json={"seed": seed, "task_id": task_id},
|
| 272 |
+
timeout=30.0,
|
| 273 |
+
)
|
| 274 |
+
resp_data = reset_resp.json()
|
| 275 |
+
except Exception as e:
|
| 276 |
+
print(f" [reset error] {e}", flush=True)
|
| 277 |
+
log_start(task=task_id, env=ENV_NAME, model=MODEL_NAME)
|
| 278 |
+
log_end(task=task_id, success=False, steps=0, score=0.0, rewards=[])
|
| 279 |
+
return {"task_id": task_id, "seed": seed, "score": 0.0, "slo_recovery": 0.0,
|
| 280 |
+
"action_efficiency": 0.0, "time_efficiency": 0.0, "steps_taken": 0,
|
| 281 |
+
"termination_reason": "reset_error", "rewards": []}
|
| 282 |
obs = resp_data.get("observation", resp_data)
|
| 283 |
|
| 284 |
max_steps = obs.get("max_steps", 10)
|
|
|
|
| 306 |
for step_num in range(1, max_steps + 1):
|
| 307 |
if done:
|
| 308 |
break
|
| 309 |
+
if _time_remaining() < 30:
|
| 310 |
+
print(f" [timeout guard] Stopping episode at step {step_num} — {_time_remaining():.0f}s left", flush=True)
|
| 311 |
+
break
|
| 312 |
|
| 313 |
user_msg = build_observation_prompt(obs)
|
| 314 |
conversation_history.append({"role": "user", "content": user_msg})
|
|
|
|
| 412 |
# ---------------------------------------------------------------------------
|
| 413 |
|
| 414 |
|
| 415 |
+
GLOBAL_TIMEOUT = 20 * 60 # 20 minutes hard cap (validator limit is 30 min)
|
| 416 |
+
_start_time: float = 0.0
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def _time_remaining() -> float:
|
| 420 |
+
return max(0, GLOBAL_TIMEOUT - (time.time() - _start_time))
|
| 421 |
+
|
| 422 |
+
|
| 423 |
def main() -> None:
|
| 424 |
+
global _start_time
|
| 425 |
+
_start_time = time.time()
|
| 426 |
+
|
| 427 |
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 428 |
|
| 429 |
+
base = ENV_URL.rstrip("/")
|
| 430 |
+
_wait_for_server(base)
|
| 431 |
+
|
| 432 |
all_tasks = {"easy": 42, "medium": 123, "hard": 7}
|
| 433 |
task_filter = os.getenv("TASKS", "").strip()
|
| 434 |
selected = [t.strip() for t in task_filter.split(",")] if task_filter else list(all_tasks)
|
|
|
|
| 444 |
|
| 445 |
results = []
|
| 446 |
for task_id, seed in tasks:
|
| 447 |
+
if _time_remaining() < 60:
|
| 448 |
+
print(f" [timeout guard] Skipping {task_id} — only {_time_remaining():.0f}s left", flush=True)
|
| 449 |
+
break
|
| 450 |
print(f"--- Task: {task_id} (seed={seed}) ---", flush=True)
|
| 451 |
result = run_episode(client, task_id, seed)
|
| 452 |
results.append(result)
|
|
|
|
| 482 |
out_file.write_text(json.dumps(payload, indent=2))
|
| 483 |
print(f"\n Results saved -> {out_file.name}", flush=True)
|
| 484 |
|
| 485 |
+
elapsed = time.time() - _start_time
|
| 486 |
total = _token_usage["prompt"] + _token_usage["completion"]
|
| 487 |
+
print(f"\n Wall time: {elapsed:.0f}s ({elapsed/60:.1f}min)", flush=True)
|
| 488 |
+
print(f" Token usage:", flush=True)
|
| 489 |
print(f" prompt: {_token_usage['prompt']:,}", flush=True)
|
| 490 |
print(f" completion: {_token_usage['completion']:,}", flush=True)
|
| 491 |
print(f" total: {total:,}", flush=True)
|