SemSorter commited on
Commit
62b47e5
·
1 Parent(s): 5d14cd4

fix: thread safety, graceful shutdown, health endpoint, and dynamic obs fallback

Browse files
SemSorter/server/agent_bridge.py CHANGED
@@ -12,9 +12,12 @@ import asyncio
12
  import logging
13
  import os
14
  import sys
 
15
  from pathlib import Path
16
  from typing import Any, Callable, Dict, List, Optional
17
 
 
 
18
  logger = logging.getLogger(__name__)
19
 
20
  # ── Path setup ────────────────────────────────────────────────────────────────
@@ -30,6 +33,8 @@ for _plugin in ("gemini", "deepgram", "elevenlabs", "getstream"):
30
  if _plugin_path.exists():
31
  sys.path.insert(0, str(_plugin_path))
32
 
 
 
33
  # ── Quota-tracking state ──────────────────────────────────────────────────────
34
  _quota_exceeded: Dict[str, bool] = {
35
  "gemini": False,
@@ -55,6 +60,7 @@ _bridge = None
55
  _llm = None
56
  _tts = None
57
  _notify_cb: Optional[Callable[[Dict], None]] = None # Push events to WebSocket
 
58
 
59
 
60
  def set_notify_callback(cb: Callable[[Dict], None]) -> None:
@@ -73,9 +79,10 @@ def _push(event: Dict) -> None:
73
 
74
 
75
  def _check_quota_error(exc: Exception) -> Optional[str]:
76
- """Return service name if the exception indicates API quota exhaustion."""
77
  msg = str(exc).lower()
78
- if "resource_exhausted" in msg or "429" in msg or "quota" in msg:
 
79
  if "gemini" in msg or "google" in msg:
80
  return "gemini"
81
  if "deepgram" in msg:
@@ -105,23 +112,25 @@ def _mark_quota_exceeded(service: str) -> None:
105
 
106
  def get_simulation():
107
  global _sim
108
- if _sim is None:
109
- os.environ.setdefault("MUJOCO_GL", "egl")
110
- from controller import SemSorterSimulation
111
- logger.info("Initialising MuJoCo simulation…")
112
- _sim = SemSorterSimulation()
113
- _sim.load_scene()
114
- _sim.step(300)
115
- logger.info("Simulation ready: %d items", len(_sim.items))
 
116
  return _sim
117
 
118
 
119
  def get_bridge():
120
  global _bridge
121
- if _bridge is None:
122
- from vlm_bridge import VLMSimBridge
123
- _bridge = VLMSimBridge(simulation=get_simulation(), use_direct=True)
124
- logger.info("VLM bridge ready")
 
125
  return _bridge
126
 
127
 
@@ -171,9 +180,7 @@ async def _scan_hazards_impl() -> Dict[str, Any]:
171
  try:
172
  bridge = get_bridge()
173
  loop = asyncio.get_event_loop()
174
- detections = await loop.run_in_executor(
175
- None, bridge.processor.detect_hazards)
176
- matched = bridge.match_detections_to_items(detections)
177
  return _format_scan(matched, demo=False)
178
  except Exception as exc:
179
  svc = _check_quota_error(exc)
@@ -216,14 +223,15 @@ async def _pick_place_impl(item_name: str, bin_type: str) -> Dict[str, Any]:
216
  return {"success": False, "error": f"{item_name} already sorted"}
217
 
218
  loop = asyncio.get_event_loop()
219
- success = await loop.run_in_executor(None, sim.pick_and_place, item_name, target)
220
  return {"success": success, "item": item_name, "bin": bin_type,
221
  "total_sorted": sim._items_sorted}
222
 
223
 
224
  def _state_impl() -> Dict[str, Any]:
225
- sim = get_simulation()
226
- state = sim.get_state()
 
227
  return {
228
  "time": round(state.time, 2),
229
  "arm_busy": state.arm_busy,
@@ -263,6 +271,31 @@ async def _sort_all_impl() -> Dict[str, Any]:
263
  "items_sorted": sorted_count, "details": details, "demo_mode": demo}
264
 
265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  # ── Text → agent response ─────────────────────────────────────────────────────
267
 
268
  async def process_text_command(text: str) -> str:
@@ -288,6 +321,20 @@ async def process_text_command(text: str) -> str:
288
  return f"Error processing command: {exc}"
289
 
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  async def _llm_demo_response(text: str) -> str:
292
  """Return a plausible demo response when Gemini quota is exhausted."""
293
  t = text.lower()
 
12
  import logging
13
  import os
14
  import sys
15
+ import threading
16
  from pathlib import Path
17
  from typing import Any, Callable, Dict, List, Optional
18
 
19
+ from dotenv import load_dotenv
20
+
21
  logger = logging.getLogger(__name__)
22
 
23
  # ── Path setup ────────────────────────────────────────────────────────────────
 
33
  if _plugin_path.exists():
34
  sys.path.insert(0, str(_plugin_path))
35
 
36
+ load_dotenv(_PROJECT_ROOT / ".env")
37
+
38
  # ── Quota-tracking state ──────────────────────────────────────────────────────
39
  _quota_exceeded: Dict[str, bool] = {
40
  "gemini": False,
 
60
  _llm = None
61
  _tts = None
62
  _notify_cb: Optional[Callable[[Dict], None]] = None # Push events to WebSocket
63
+ _sim_lock = threading.RLock()
64
 
65
 
66
  def set_notify_callback(cb: Callable[[Dict], None]) -> None:
 
79
 
80
 
81
  def _check_quota_error(exc: Exception) -> Optional[str]:
82
+ """Return service name if the exception indicates quota/auth API failures."""
83
  msg = str(exc).lower()
84
+ if ("resource_exhausted" in msg or "429" in msg or "quota" in msg
85
+ or "invalid api key" in msg or "unauthorized" in msg or "401" in msg):
86
  if "gemini" in msg or "google" in msg:
87
  return "gemini"
88
  if "deepgram" in msg:
 
112
 
113
  def get_simulation():
114
  global _sim
115
+ with _sim_lock:
116
+ if _sim is None:
117
+ os.environ.setdefault("MUJOCO_GL", "egl")
118
+ from controller import SemSorterSimulation
119
+ logger.info("Initialising MuJoCo simulation…")
120
+ _sim = SemSorterSimulation()
121
+ _sim.load_scene()
122
+ _sim.step(300)
123
+ logger.info("Simulation ready: %d items", len(_sim.items))
124
  return _sim
125
 
126
 
127
  def get_bridge():
128
  global _bridge
129
+ with _sim_lock:
130
+ if _bridge is None:
131
+ from vlm_bridge import VLMSimBridge
132
+ _bridge = VLMSimBridge(simulation=get_simulation(), use_direct=True)
133
+ logger.info("VLM bridge ready")
134
  return _bridge
135
 
136
 
 
180
  try:
181
  bridge = get_bridge()
182
  loop = asyncio.get_event_loop()
183
+ detections, matched = await loop.run_in_executor(None, _detect_and_match_impl)
 
 
184
  return _format_scan(matched, demo=False)
185
  except Exception as exc:
186
  svc = _check_quota_error(exc)
 
223
  return {"success": False, "error": f"{item_name} already sorted"}
224
 
225
  loop = asyncio.get_event_loop()
226
+ success = await loop.run_in_executor(None, _pick_place_sync, sim, item_name, target)
227
  return {"success": success, "item": item_name, "bin": bin_type,
228
  "total_sorted": sim._items_sorted}
229
 
230
 
231
  def _state_impl() -> Dict[str, Any]:
232
+ with _sim_lock:
233
+ sim = get_simulation()
234
+ state = sim.get_state()
235
  return {
236
  "time": round(state.time, 2),
237
  "arm_busy": state.arm_busy,
 
271
  "items_sorted": sorted_count, "details": details, "demo_mode": demo}
272
 
273
 
274
+ def render_frame(camera: str = "overview"):
275
+ """Thread-safe simulation frame render for the video WS endpoint."""
276
+ with _sim_lock:
277
+ sim = get_simulation()
278
+ return sim.render_frame(camera=camera)
279
+
280
+
281
+ def close_resources() -> None:
282
+ """Best-effort shutdown for long-running server process."""
283
+ global _bridge, _sim
284
+ with _sim_lock:
285
+ if _bridge is not None:
286
+ try:
287
+ _bridge.close()
288
+ except Exception:
289
+ pass
290
+ _bridge = None
291
+ if _sim is not None and hasattr(_sim, "close"):
292
+ try:
293
+ _sim.close()
294
+ except Exception:
295
+ pass
296
+ _sim = None
297
+
298
+
299
  # ── Text → agent response ─────────────────────────────────────────────────────
300
 
301
  async def process_text_command(text: str) -> str:
 
321
  return f"Error processing command: {exc}"
322
 
323
 
324
+ def _detect_and_match_impl():
325
+ """Run detect+match atomically to avoid simulation/render race conditions."""
326
+ with _sim_lock:
327
+ bridge = get_bridge()
328
+ detections = bridge.processor.detect_hazards()
329
+ matched = bridge.match_detections_to_items(detections)
330
+ return detections, matched
331
+
332
+
333
+ def _pick_place_sync(sim, item_name: str, target) -> bool:
334
+ with _sim_lock:
335
+ return sim.pick_and_place(item_name, target)
336
+
337
+
338
  async def _llm_demo_response(text: str) -> str:
339
  """Return a plausible demo response when Gemini quota is exhausted."""
340
  t = text.lower()
SemSorter/server/app.py CHANGED
@@ -24,12 +24,11 @@ import json
24
  import logging
25
  import os
26
  from pathlib import Path
27
- from typing import Set
28
 
29
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File
30
  from fastapi.responses import HTMLResponse, JSONResponse
31
  from fastapi.staticfiles import StaticFiles
32
- import numpy as np
33
  from PIL import Image
34
 
35
  # ── Local imports ─────────────────────────────────────────────────────────────
@@ -48,10 +47,12 @@ _STATIC.mkdir(exist_ok=True)
48
  # ── Connected WebSocket clients ───────────────────────────────────────────────
49
  _chat_clients: Set[WebSocket] = set()
50
  _video_clients: Set[WebSocket] = set()
 
51
 
52
 
53
  async def _broadcast_chat(event: dict) -> None:
54
  """Push a JSON event to all connected chat WebSocket clients."""
 
55
  payload = json.dumps(event)
56
  dead = set()
57
  for ws in list(_chat_clients):
@@ -59,17 +60,20 @@ async def _broadcast_chat(event: dict) -> None:
59
  await ws.send_text(payload)
60
  except Exception:
61
  dead.add(ws)
62
- _chat_clients -= dead
 
63
 
64
 
65
  def _sync_broadcast(event: dict) -> None:
66
  """Thread-safe push called from sync code (bridge callbacks)."""
 
 
67
  try:
68
- loop = asyncio.get_event_loop()
69
- if loop.is_running():
70
- asyncio.create_task(_broadcast_chat(event))
71
  except Exception:
72
- pass
73
 
74
 
75
  # Register the broadcast callback so agent_bridge can push quota warnings
@@ -79,12 +83,21 @@ bridge.set_notify_callback(_sync_broadcast)
79
  # ── Startup: pre-warm simulation ──────────────────────────────────────────────
80
  @app.on_event("startup")
81
  async def startup():
 
 
82
  logger.info("Pre-warming MuJoCo simulation…")
83
- loop = asyncio.get_event_loop()
84
- await loop.run_in_executor(None, bridge.get_simulation)
85
  logger.info("Simulation ready")
86
 
87
 
 
 
 
 
 
 
 
 
88
  # ── REST endpoints ────────────────────────────────────────────────────────────
89
 
90
  @app.get("/", response_class=HTMLResponse)
@@ -100,6 +113,11 @@ async def api_state():
100
  return JSONResponse(state)
101
 
102
 
 
 
 
 
 
103
  @app.post("/api/sort")
104
  async def api_sort():
105
  """Trigger the full detect-match-sort pipeline."""
@@ -178,8 +196,7 @@ async def ws_chat(ws: WebSocket):
178
 
179
  def _render_frame_jpeg(quality: int = 75) -> bytes:
180
  """Render a MuJoCo frame and encode as JPEG bytes."""
181
- sim = bridge.get_simulation()
182
- frame = sim.render_frame(camera="overview") # numpy H×W×3
183
  img = Image.fromarray(frame)
184
  buf = io.BytesIO()
185
  img.save(buf, format="JPEG", quality=quality)
 
24
  import logging
25
  import os
26
  from pathlib import Path
27
+ from typing import Optional, Set
28
 
29
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File
30
  from fastapi.responses import HTMLResponse, JSONResponse
31
  from fastapi.staticfiles import StaticFiles
 
32
  from PIL import Image
33
 
34
  # ── Local imports ─────────────────────────────────────────────────────────────
 
47
  # ── Connected WebSocket clients ───────────────────────────────────────────────
48
  _chat_clients: Set[WebSocket] = set()
49
  _video_clients: Set[WebSocket] = set()
50
+ _main_loop: Optional[asyncio.AbstractEventLoop] = None
51
 
52
 
53
  async def _broadcast_chat(event: dict) -> None:
54
  """Push a JSON event to all connected chat WebSocket clients."""
55
+ global _chat_clients
56
  payload = json.dumps(event)
57
  dead = set()
58
  for ws in list(_chat_clients):
 
60
  await ws.send_text(payload)
61
  except Exception:
62
  dead.add(ws)
63
+ for ws in dead:
64
+ _chat_clients.discard(ws)
65
 
66
 
67
  def _sync_broadcast(event: dict) -> None:
68
  """Thread-safe push called from sync code (bridge callbacks)."""
69
+ if _main_loop is None:
70
+ return
71
  try:
72
+ _main_loop.call_soon_threadsafe(
73
+ asyncio.create_task, _broadcast_chat(event)
74
+ )
75
  except Exception:
76
+ logger.exception("Failed to schedule chat broadcast")
77
 
78
 
79
  # Register the broadcast callback so agent_bridge can push quota warnings
 
83
  # ── Startup: pre-warm simulation ──────────────────────────────────────────────
84
  @app.on_event("startup")
85
  async def startup():
86
+ global _main_loop
87
+ _main_loop = asyncio.get_running_loop()
88
  logger.info("Pre-warming MuJoCo simulation…")
89
+ await _main_loop.run_in_executor(None, bridge.get_simulation)
 
90
  logger.info("Simulation ready")
91
 
92
 
93
+ @app.on_event("shutdown")
94
+ async def shutdown():
95
+ logger.info("Shutting down SemSorter resources…")
96
+ loop = asyncio.get_running_loop()
97
+ await loop.run_in_executor(None, bridge.close_resources)
98
+ logger.info("Shutdown complete")
99
+
100
+
101
  # ── REST endpoints ────────────────────────────────────────────────────────────
102
 
103
  @app.get("/", response_class=HTMLResponse)
 
113
  return JSONResponse(state)
114
 
115
 
116
+ @app.get("/health")
117
+ async def health():
118
+ return JSONResponse({"ok": True})
119
+
120
+
121
  @app.post("/api/sort")
122
  async def api_sort():
123
  """Trigger the full detect-match-sort pipeline."""
 
196
 
197
  def _render_frame_jpeg(quality: int = 75) -> bytes:
198
  """Render a MuJoCo frame and encode as JPEG bytes."""
199
+ frame = bridge.render_frame(camera="overview") # numpy H×W×3
 
200
  img = Image.fromarray(frame)
201
  buf = io.BytesIO()
202
  img.save(buf, format="JPEG", quality=quality)
SemSorter/vision/vision_pipeline.py CHANGED
@@ -14,13 +14,11 @@ Usage:
14
 
15
  import os
16
  import sys
17
- import cv2
18
  import json
19
- import time
20
  import logging
21
  import google.generativeai as genai
22
  from PIL import Image
23
- from typing import List, Dict, Optional
24
 
25
  logger = logging.getLogger(__name__)
26
 
@@ -98,6 +96,14 @@ class HazardDetectionProcessor:
98
 
99
  def _capture_from_obs(self) -> Image.Image:
100
  """Capture a frame from the OBS Virtual Camera."""
 
 
 
 
 
 
 
 
101
  if self._video_cap is None or not self._video_cap.isOpened():
102
  self._video_cap = cv2.VideoCapture(self.device_id)
103
  if not self._video_cap.isOpened():
 
14
 
15
  import os
16
  import sys
 
17
  import json
 
18
  import logging
19
  import google.generativeai as genai
20
  from PIL import Image
21
+ from typing import List, Dict
22
 
23
  logger = logging.getLogger(__name__)
24
 
 
96
 
97
  def _capture_from_obs(self) -> Image.Image:
98
  """Capture a frame from the OBS Virtual Camera."""
99
+ try:
100
+ import cv2
101
+ except ImportError as exc:
102
+ raise RuntimeError(
103
+ "OpenCV is required for OBS capture mode. "
104
+ "Install opencv-python or opencv-python-headless."
105
+ ) from exc
106
+
107
  if self._video_cap is None or not self._video_cap.isOpened():
108
  self._video_cap = cv2.VideoCapture(self.device_id)
109
  if not self._video_cap.isOpened():
render.yaml CHANGED
@@ -17,5 +17,5 @@ services:
17
  sync: false
18
  - key: STREAM_API_SECRET
19
  sync: false
20
- healthCheckPath: /api/state
21
  autoDeploy: true
 
17
  sync: false
18
  - key: STREAM_API_SECRET
19
  sync: false
20
+ healthCheckPath: /health
21
  autoDeploy: true