eeshwar143 committed on
Commit
3e6da75
·
1 Parent(s): e4accbb

Harden inference bootstrap and container startup

Browse files
Files changed (2) hide show
  1. inference.py +52 -27
  2. support_queue_env/client.py +88 -5
inference.py CHANGED
@@ -25,7 +25,6 @@ def log_start(task: str, env: str, model: str) -> None:
25
  print(f"[START] task={task} env={env} model={model}", flush=True)
26
 
27
 
28
-
29
  def log_step(step: int, action: str, reward: float, done: bool, error: str | None) -> None:
30
  error_value = "none" if error is None else error.replace("\n", " ")
31
  print(
@@ -34,7 +33,6 @@ def log_step(step: int, action: str, reward: float, done: bool, error: str | Non
34
  )
35
 
36
 
37
-
38
  def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
39
  print(
40
  f"[END] success={str(success).lower()} steps={steps} score={score:.4f} rewards={json.dumps([round(r, 4) for r in rewards])}",
@@ -42,7 +40,6 @@ def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> No
42
  )
43
 
44
 
45
-
46
  def get_model_message(
47
  client: OpenAI,
48
  step: int,
@@ -72,7 +69,6 @@ def get_model_message(
72
  return "hello"
73
 
74
 
75
-
76
  def available_tasks() -> list[TaskCard]:
77
  return [
78
  TaskCard(
@@ -86,7 +82,6 @@ def available_tasks() -> list[TaskCard]:
86
  ]
87
 
88
 
89
-
90
  def heuristic_action(observation: SupportQueueObservation) -> SupportQueueAction:
91
  text = " ".join(
92
  [
@@ -193,9 +188,7 @@ def heuristic_action(observation: SupportQueueObservation) -> SupportQueueAction
193
  )
194
 
195
 
196
- async def run_task(client: OpenAI, task: TaskCard) -> dict[str, Any]:
197
- env = await SupportQueueEnv.from_docker_image(LOCAL_IMAGE_NAME)
198
-
199
  history: List[str] = []
200
  rewards: List[float] = []
201
  steps_taken = 0
@@ -216,7 +209,13 @@ async def run_task(client: OpenAI, task: TaskCard) -> dict[str, Any]:
216
  _ = get_model_message(client, step, observation, last_reward, history)
217
  action = heuristic_action(observation)
218
 
219
- result = await env.step(action)
 
 
 
 
 
 
220
  reward = result.reward or 0.0
221
  done = result.done
222
  error = None
@@ -237,11 +236,10 @@ async def run_task(client: OpenAI, task: TaskCard) -> dict[str, Any]:
237
  score = min(max(score, 0.0), 1.0)
238
  success = score >= SUCCESS_SCORE_THRESHOLD
239
 
 
 
 
240
  finally:
241
- try:
242
- await env.close()
243
- except Exception as exc:
244
- print(f"[DEBUG] env.close() error (container cleanup): {exc}", flush=True)
245
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
246
 
247
  return {
@@ -254,21 +252,48 @@ async def run_task(client: OpenAI, task: TaskCard) -> dict[str, Any]:
254
 
255
 
256
  async def main() -> None:
257
- client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
258
- results = []
259
-
260
- for task in available_tasks():
261
- results.append(await run_task(client, task))
262
 
263
- aggregate = {
264
- "benchmark": BENCHMARK,
265
- "model": MODEL_NAME,
266
- "average_score": round(sum(item["score"] for item in results) / len(results), 4) if results else 0.0,
267
- "tasks": results,
268
- }
269
- with open("inference_results.json", "w", encoding="utf-8") as handle:
270
- json.dump(aggregate, handle, indent=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
 
273
  if __name__ == "__main__":
274
- asyncio.run(main())
 
 
 
 
25
  print(f"[START] task={task} env={env} model={model}", flush=True)
26
 
27
 
 
28
  def log_step(step: int, action: str, reward: float, done: bool, error: str | None) -> None:
29
  error_value = "none" if error is None else error.replace("\n", " ")
30
  print(
 
33
  )
34
 
35
 
 
36
  def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
37
  print(
38
  f"[END] success={str(success).lower()} steps={steps} score={score:.4f} rewards={json.dumps([round(r, 4) for r in rewards])}",
 
40
  )
41
 
42
 
 
43
  def get_model_message(
44
  client: OpenAI,
45
  step: int,
 
69
  return "hello"
70
 
71
 
 
72
  def available_tasks() -> list[TaskCard]:
73
  return [
74
  TaskCard(
 
82
  ]
83
 
84
 
 
85
  def heuristic_action(observation: SupportQueueObservation) -> SupportQueueAction:
86
  text = " ".join(
87
  [
 
188
  )
189
 
190
 
191
+ async def run_task(client: OpenAI, env: SupportQueueEnv, task: TaskCard) -> dict[str, Any]:
 
 
192
  history: List[str] = []
193
  rewards: List[float] = []
194
  steps_taken = 0
 
209
  _ = get_model_message(client, step, observation, last_reward, history)
210
  action = heuristic_action(observation)
211
 
212
+ try:
213
+ result = await env.step(action)
214
+ except Exception as exc:
215
+ action_payload = json.dumps(action.model_dump(), separators=(",", ":"), sort_keys=True)
216
+ log_step(step=step, action=action_payload, reward=0.0, done=True, error=str(exc))
217
+ break
218
+
219
  reward = result.reward or 0.0
220
  done = result.done
221
  error = None
 
236
  score = min(max(score, 0.0), 1.0)
237
  success = score >= SUCCESS_SCORE_THRESHOLD
238
 
239
+ except Exception as exc:
240
+ print(f"[DEBUG] Task {task.task_id} failed: {exc}", flush=True)
241
+
242
  finally:
 
 
 
 
243
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
244
 
245
  return {
 
252
 
253
 
254
async def main() -> None:
    """Run every benchmark task against one shared environment and dump results.

    Failures are logged rather than raised: any task that did not produce a
    real result is recorded as a zero-score failure, so
    `inference_results.json` is always written. The environment container is
    always cleaned up in the `finally` block.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "placeholder")
    tasks = available_tasks()
    results: list[dict[str, Any]] = []
    env: SupportQueueEnv | None = None

    try:
        env = await SupportQueueEnv.from_docker_image(LOCAL_IMAGE_NAME)
        for task in tasks:
            results.append(await run_task(client, env, task))
    except Exception as exc:
        print(f"[DEBUG] Environment bootstrap failed: {exc}", flush=True)
        # Backfill only tasks without a real result. Iterating *all* tasks here
        # would duplicate entries for tasks that already completed when the
        # failure happened mid-run rather than at bootstrap.
        for task in tasks[len(results):]:
            log_start(task=task.task_id, env=BENCHMARK, model=MODEL_NAME)
            log_end(success=False, steps=0, score=0.0, rewards=[])
            results.append(
                {
                    "task_id": task.task_id,
                    "score": 0.0,
                    "steps": 0,
                    "rewards": [],
                    "success": False,
                }
            )
    finally:
        if env is not None:
            try:
                await env.close()
            except Exception as exc:
                # Cleanup is best-effort; a failed removal must not mask results.
                print(f"[DEBUG] env.close() error (container cleanup): {exc}", flush=True)

    aggregate = {
        "benchmark": BENCHMARK,
        "model": MODEL_NAME,
        "average_score": round(sum(item["score"] for item in results) / len(results), 4) if results else 0.0,
        "tasks": results,
    }
    with open("inference_results.json", "w", encoding="utf-8") as handle:
        json.dump(aggregate, handle, indent=2)
293
 
294
 
295
if __name__ == "__main__":
    # Last-resort boundary: surface any unhandled failure as a log line
    # instead of a raw traceback, keeping container logs parseable.
    try:
        asyncio.run(main())
    except Exception as exc:
        print(f"[DEBUG] Fatal inference error: {exc}", flush=True)
support_queue_env/client.py CHANGED
@@ -4,13 +4,22 @@ from __future__ import annotations
4
 
5
  import asyncio
6
  import os
 
 
 
7
  from typing import Any
8
 
9
  import requests
10
 
11
  from support_queue_env.models import TaskCard, SupportQueueAction, SupportQueueObservation, SupportQueueState
12
 
13
- DEFAULT_ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://127.0.0.1:8000")
 
 
 
 
 
 
14
 
15
 
16
  class _Result:
@@ -21,8 +30,9 @@ class _Result:
21
 
22
 
23
  class SupportQueueEnv:
24
- def __init__(self, base_url: str) -> None:
25
  self.base_url = base_url.rstrip("/")
 
26
 
27
  @classmethod
28
  def from_base_url(cls, base_url: str) -> "SupportQueueEnv":
@@ -30,8 +40,80 @@ class SupportQueueEnv:
30
 
31
  @classmethod
32
  async def from_docker_image(cls, image_name: str | None = None) -> "SupportQueueEnv":
33
- _ = image_name
34
- return cls(base_url=DEFAULT_ENV_BASE_URL)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  def list_tasks(self) -> list[TaskCard]:
37
  response = requests.get(f"{self.base_url}/tasks", timeout=30)
@@ -67,4 +149,5 @@ class SupportQueueEnv:
67
  return await asyncio.to_thread(self.state_sync)
68
 
69
  async def close(self) -> None:
70
- return None
 
 
4
 
5
  import asyncio
6
  import os
7
+ import socket
8
+ import subprocess
9
+ import time
10
  from typing import Any
11
 
12
  import requests
13
 
14
  from support_queue_env.models import TaskCard, SupportQueueAction, SupportQueueObservation, SupportQueueState
15
 
16
# Explicit override: when ENV_BASE_URL is set, connect to that already-running
# server instead of starting a local docker container.
DEFAULT_ENV_BASE_URL = os.getenv("ENV_BASE_URL")

# Image tags probed in order before falling back to `docker build`.
DEFAULT_IMAGE_CANDIDATES = [
    "support-queue-openenv:latest",
    "support-queue-openenv",
    "support_queue_env:latest",
    "support_queue_env",
]
23
 
24
 
25
  class _Result:
 
30
 
31
 
32
  class SupportQueueEnv:
33
+ def __init__(self, base_url: str, container_id: str | None = None) -> None:
34
  self.base_url = base_url.rstrip("/")
35
+ self.container_id = container_id
36
 
37
  @classmethod
38
  def from_base_url(cls, base_url: str) -> "SupportQueueEnv":
 
40
 
41
  @classmethod
42
  async def from_docker_image(cls, image_name: str | None = None) -> "SupportQueueEnv":
43
+ if DEFAULT_ENV_BASE_URL:
44
+ return cls(base_url=DEFAULT_ENV_BASE_URL)
45
+ return await asyncio.to_thread(cls._spawn_local_container, image_name)
46
+
47
+ @classmethod
48
+ def _spawn_local_container(cls, image_name: str | None) -> "SupportQueueEnv":
49
+ chosen_image = cls._resolve_image_name(image_name)
50
+ port = cls._pick_free_port()
51
+ container_id = cls._run(["docker", "run", "-d", "-p", f"{port}:8000", chosen_image]).strip()
52
+ base_url = f"http://127.0.0.1:{port}"
53
+
54
+ try:
55
+ cls._wait_until_ready(base_url)
56
+ except Exception:
57
+ cls._safe_remove_container(container_id)
58
+ raise
59
+
60
+ return cls(base_url=base_url, container_id=container_id)
61
+
62
+ @classmethod
63
+ def _resolve_image_name(cls, image_name: str | None) -> str:
64
+ candidates: list[str] = []
65
+ if image_name:
66
+ candidates.append(image_name)
67
+ candidates.extend(DEFAULT_IMAGE_CANDIDATES)
68
+
69
+ for candidate in candidates:
70
+ if cls._image_exists(candidate):
71
+ return candidate
72
+
73
+ build_tag = image_name or "support-queue-openenv:local"
74
+ cls._run(["docker", "build", "-t", build_tag, "."])
75
+ return build_tag
76
+
77
+ @staticmethod
78
+ def _image_exists(image_name: str) -> bool:
79
+ try:
80
+ SupportQueueEnv._run(["docker", "image", "inspect", image_name])
81
+ return True
82
+ except RuntimeError:
83
+ return False
84
+
85
+ @staticmethod
86
+ def _pick_free_port() -> int:
87
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
88
+ sock.bind(("127.0.0.1", 0))
89
+ return int(sock.getsockname()[1])
90
+
91
+ @staticmethod
92
+ def _wait_until_ready(base_url: str, timeout_seconds: int = 45) -> None:
93
+ deadline = time.time() + timeout_seconds
94
+ last_error = ""
95
+
96
+ while time.time() < deadline:
97
+ try:
98
+ response = requests.get(f"{base_url}/health", timeout=3)
99
+ if response.ok:
100
+ return
101
+ except Exception as exc:
102
+ last_error = str(exc)
103
+ time.sleep(1)
104
+
105
+ raise RuntimeError(f"Environment did not become ready at {base_url}: {last_error}")
106
+
107
+ @staticmethod
108
+ def _run(command: list[str]) -> str:
109
+ result = subprocess.run(command, check=False, capture_output=True, text=True)
110
+ if result.returncode != 0:
111
+ raise RuntimeError((result.stderr or result.stdout).strip() or f"Command failed: {' '.join(command)}")
112
+ return result.stdout
113
+
114
+ @staticmethod
115
+ def _safe_remove_container(container_id: str) -> None:
116
+ subprocess.run(["docker", "rm", "-f", container_id], check=False, capture_output=True, text=True)
117
 
118
  def list_tasks(self) -> list[TaskCard]:
119
  response = requests.get(f"{self.base_url}/tasks", timeout=30)
 
149
  return await asyncio.to_thread(self.state_sync)
150
 
151
  async def close(self) -> None:
152
+ if self.container_id:
153
+ await asyncio.to_thread(self._safe_remove_container, self.container_id)