nxdev-org committed on
Commit
e0d7edf
·
1 Parent(s): fa124e7

update session management and response-completion verification

Browse files
Files changed (1) hide show
  1. api_server.py +186 -364
api_server.py CHANGED
@@ -1,29 +1,28 @@
1
  #!/usr/bin/env python3
2
  """
3
- api_server.py — Enterprise-grade OpenAI-Compatible MQTT Bridge
4
 
5
  Features:
6
- - Strict OpenAI API Spec compliance (Pydantic v2).
7
- - Proactive Worker Discovery (Broadcast Ping).
8
- - Robust Asyncio/Thread bridging.
9
- - Graceful handling of client disconnects and timeouts.
10
  """
11
 
12
  import json
13
  import time
14
  import uuid
15
  import asyncio
16
- import argparse
17
  import logging
18
  import os
19
- import threading
20
  from typing import Optional, List, Dict, Any, Union, Literal, AsyncGenerator
21
  from contextlib import asynccontextmanager
22
 
23
- from fastapi import FastAPI, HTTPException, Header, Request, status
 
24
  from fastapi.responses import StreamingResponse, JSONResponse
25
  from fastapi.middleware.cors import CORSMiddleware
26
- from pydantic import BaseModel, Field
27
  import paho.mqtt.client as mqtt
28
  from paho.mqtt.client import CallbackAPIVersion
29
  import uvicorn
@@ -35,316 +34,206 @@ class Config:
35
  BROKER_HOST = os.getenv("MQTT_BROKER_HOST", "nxdev-org-mqtt-broker.hf.space")
36
  BROKER_PORT = int(os.getenv("MQTT_BROKER_PORT", "443"))
37
  USE_TLS = os.getenv("MQTT_USE_TLS", "true").lower() in ("1", "true", "yes")
38
-
39
- # BROKER_HOST = os.getenv("MQTT_BROKER_HOST", "localhost")
40
- # BROKER_PORT = int(os.getenv("MQTT_BROKER_PORT", "7860"))
41
- # USE_TLS = os.getenv("MQTT_USE_TLS", "false").lower() in ("1", "true", "yes")
42
-
43
  WS_PATH = os.getenv("MQTT_WS_PATH", "/mqtt")
44
 
45
- API_HOST = os.getenv("API_HOST", "0.0.0.0")
46
- API_PORT = int(os.getenv("API_PORT", "8001"))
47
 
48
- TIMEOUT_SEC = 60.0 # Max time to wait for first token
49
- SESSION_EXPIRY = 120.0 # Seconds before a model is considered offline
50
 
51
  config = Config()
52
- logger = logging.getLogger("arena-api")
 
53
 
54
  # ============================================================
55
- # OPENAI PYDANTIC MODELS (Strict Spec)
56
  # ============================================================
57
 
58
  class ChatMessage(BaseModel):
59
  role: str
60
  content: str
61
- name: Optional[str] = None
62
 
63
  class ChatCompletionRequest(BaseModel):
64
- model: str = "auto"
65
  messages: List[ChatMessage]
66
- temperature: Optional[float] = 1.0
67
- top_p: Optional[float] = 1.0
68
- n: Optional[int] = 1
69
- stream: Optional[bool] = False
70
- stop: Optional[Union[str, List[str]]] = None
71
- max_tokens: Optional[int] = None
72
- presence_penalty: Optional[float] = 0.0
73
- frequency_penalty: Optional[float] = 0.0
74
- user: Optional[str] = None
75
-
76
- # -- Responses --
77
 
78
  class ChoiceDelta(BaseModel):
79
- role: Optional[str] = None
80
  content: Optional[str] = None
81
- reasoning_content: Optional[str] = None # DeepSeek/Thinking extension
82
-
83
- class ChoiceMessage(BaseModel):
84
- role: str = "assistant"
85
- content: Optional[str] = ""
86
  reasoning_content: Optional[str] = None
87
 
88
- class Choice(BaseModel):
89
- index: int
90
- message: ChoiceMessage
91
- finish_reason: Optional[str] = None
92
-
93
  class ChoiceChunk(BaseModel):
94
- index: int
95
  delta: ChoiceDelta
96
  finish_reason: Optional[str] = None
97
-
98
- class UsageInfo(BaseModel):
99
- prompt_tokens: int = 0
100
- completion_tokens: int = 0
101
- total_tokens: int = 0
102
-
103
- class ChatCompletionResponse(BaseModel):
104
- id: str
105
- object: Literal["chat.completion"] = "chat.completion"
106
- created: int
107
- model: str
108
- system_fingerprint: Optional[str] = "fp_mqtt_bridge"
109
- choices: List[Choice]
110
- usage: UsageInfo
111
 
112
  class ChatCompletionChunk(BaseModel):
113
  id: str
114
- object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
115
  created: int
116
  model: str
117
- system_fingerprint: Optional[str] = "fp_mqtt_bridge"
118
  choices: List[ChoiceChunk]
119
 
120
  # ============================================================
121
- # MQTT BRIDGE CORE
122
  # ============================================================
123
 
124
- class MQTTBridge:
125
  def __init__(self):
126
- self.client_id = f"api-srv-{uuid.uuid4().hex[:8]}"
127
- self.sessions: Dict[str, Dict] = {} # { session_id: { last_seen, host, status } }
128
-
129
- # Request routing: req_id -> asyncio.Queue
130
- self._response_queues: Dict[str, asyncio.Queue] = {}
131
  self._loop: Optional[asyncio.AbstractEventLoop] = None
132
 
133
- # Paho Client
134
- self.client = mqtt.Client(
135
  callback_api_version=CallbackAPIVersion.VERSION2,
136
  client_id=self.client_id,
137
- transport="websockets",
138
- protocol=mqtt.MQTTv311,
139
  )
140
-
141
  if config.USE_TLS:
142
- self.client.tls_set()
143
-
144
- self.client.ws_set_options(path=config.WS_PATH, headers={"Sec-WebSocket-Protocol": "mqtt"})
145
- self.client.on_connect = self._on_connect
146
- self.client.on_message = self._on_message
147
- self.client.on_disconnect = self._on_disconnect
148
 
149
  def set_loop(self, loop):
150
  self._loop = loop
151
 
152
- def connect(self):
153
- logger.info(f"🔌 Connecting to Broker: {config.BROKER_HOST}:{config.BROKER_PORT}")
154
- try:
155
- self.client.connect(config.BROKER_HOST, config.BROKER_PORT, keepalive=60)
156
- self.client.loop_start()
157
- except Exception as e:
158
- logger.error(f"❌ Connection failed: {e}")
159
- raise e
160
-
161
- def disconnect(self):
162
- self.client.loop_stop()
163
- self.client.disconnect()
164
-
165
- # --- MQTT Callbacks (Threaded) ---
166
-
167
  def _on_connect(self, client, userdata, flags, rc, props=None):
168
  if rc == 0:
169
- logger.info("✅ MQTT Connected. Subscribing to heartbeats...")
170
- client.subscribe("arena-ai/+/response") # Listen for all worker responses
171
- client.subscribe("arena-ai/global/heartbeat") # Listen for worker presence
172
-
173
- # PROACTIVE DISCOVERY: Tell all workers to announce themselves immediately
174
- self.client.publish("arena-ai/global/discovery", "ping", retain=False)
175
  else:
176
- logger.error(f"❌ MQTT Connect Failed. RC={rc}")
177
-
178
- def _on_disconnect(self, client, userdata, flags, rc, props=None):
179
- logger.warning(f"⚠️ MQTT Disconnected (RC={rc})")
180
 
181
  def _on_message(self, client, userdata, msg):
182
  try:
183
  topic = msg.topic
184
  payload = json.loads(msg.payload.decode())
185
 
186
- # 1. Heartbeat Handling
187
  if topic == "arena-ai/global/heartbeat":
188
- print(f"⚠️ Get heartbeat payload:{payload}")
189
  sid = payload.get("id")
190
  if sid:
191
- status = {
192
  "last_seen": time.time(),
193
- "host": payload.get("host", "unknown"),
194
  "status": payload.get("status", "ready"),
195
- "model": payload.get("model", sid)
196
  }
197
- print(f"⚠️ Update sessions sid:{sid}, status:{status}")
198
- self.sessions[sid] = status
199
  return
200
 
201
- # 2. Response Handling
202
- # Topic format: arena-ai/{session_id}/response
203
  if topic.endswith("/response"):
204
- req_id = payload.get("id")
205
- if req_id and req_id in self._response_queues:
206
- # Thread-safe put into asyncio queue
207
- if self._loop:
208
- self._loop.call_soon_threadsafe(
209
- self._response_queues[req_id].put_nowait, payload
210
- )
211
-
212
  except Exception as e:
213
- logger.error(f"Message processing error: {e}")
214
 
215
- # --- Async API Methods ---
216
-
217
- def get_active_models(self) -> List[Dict]:
218
- """Return list of models seen in the last SESSION_EXPIRY seconds."""
219
  now = time.time()
220
- active = []
221
- # Filter stale sessions
222
- stale_ids = []
223
- for sid, info in self.sessions.items():
224
- if now - info["last_seen"] > config.SESSION_EXPIRY:
225
- stale_ids.append(sid)
226
  else:
227
- active.append({
228
- "id": sid,
229
- "object": "model",
230
- "created": int(info["last_seen"]),
231
- "owned_by": info["host"]
232
- })
233
-
234
- # Cleanup stale
235
- for sid in stale_ids:
236
- del self.sessions[sid]
237
-
238
  return active
239
 
240
- async def send_chat_request(self, req: ChatCompletionRequest) -> AsyncGenerator[Dict, None]:
241
- """
242
- Orchestrates the request:
243
- 1. Selects Model
244
- 2. Publishes MQTT Request
245
- 3. Listens for MQTT Responses
246
- 4. Yields chunks (or full response)
247
- """
248
- # 1. Resolve Model
249
- target_session = self._resolve_session(req.model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
- # 2. Prepare Request
252
- req_id = uuid.uuid4().hex
253
- response_queue = asyncio.Queue()
254
- self._response_queues[req_id] = response_queue
255
-
256
  mqtt_payload = {
257
  "id": req_id,
258
  "messages": [m.model_dump() for m in req.messages],
259
  "stream": req.stream,
260
- "temperature": req.temperature,
261
- # Pass other params if needed
262
  }
263
-
264
- topic = f"arena-ai/{target_session}/request"
265
- logger.info(f"📤 Routing request {req_id[:6]} -> {target_session}")
266
 
267
  try:
268
- # Publish (Non-blocking)
269
- self.client.publish(topic, json.dumps(mqtt_payload), qos=1)
270
-
271
- # 3. Wait for Responses
272
- start_time = time.time()
273
- first_packet = True
274
 
 
 
275
  while True:
276
- # Calculate Timeout
277
- elapsed = time.time() - start_time
278
- remaining = config.TIMEOUT_SEC - elapsed
279
 
280
- if remaining <= 0:
281
- raise HTTPException(status_code=504, detail="Worker timed out waiting for response")
282
-
283
- try:
284
- # Wait for next chunk
285
- raw_msg = await asyncio.wait_for(response_queue.get(), timeout=remaining)
286
- except asyncio.TimeoutError:
287
- raise HTTPException(status_code=504, detail="Worker timed out")
288
-
289
- # Check for worker-side errors
290
- if "error" in raw_msg:
291
- raise HTTPException(status_code=502, detail=f"Worker Error: {raw_msg['error']}")
292
-
293
- # Yield logic
294
- yield raw_msg
295
 
296
- # Check for done
297
- choices = raw_msg.get("choices", [])
298
  if choices and choices[0].get("finish_reason"):
 
 
 
 
 
 
 
 
 
 
299
  break
300
-
301
- # Standardize object type check
302
- if raw_msg.get("object") == "chat.completion":
303
- break
304
-
305
  finally:
306
- # Cleanup
307
- if req_id in self._response_queues:
308
- del self._response_queues[req_id]
309
-
310
- def _resolve_session(self, model_name: str) -> str:
311
- # 1. Exact Match (Session ID)
312
- if model_name in self.sessions:
313
- return model_name
314
-
315
- # 2. "auto" or empty -> Pick any active
316
- active = self.get_active_models()
317
- if not active:
318
- raise HTTPException(status_code=503, detail="No active workers found. Please open the UserScript.")
319
-
320
- if model_name in ["auto", "default", "gpt-3.5-turbo"]:
321
- return active[0]["id"]
322
-
323
- # 3. Search by partial name (if worker sends friendly model name)
324
- # (Simplified: just assuming model_name == session_id for now)
325
- raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found.")
326
 
327
  # ============================================================
328
- # FASTAPI APP
329
  # ============================================================
330
 
331
- bridge = MQTTBridge()
332
 
333
  @asynccontextmanager
334
  async def lifespan(app: FastAPI):
335
- # Startup
336
- print(f"🚀 Starting API Server on {config.API_HOST}:{config.API_PORT}")
337
- print(f"🔗 Bridging to MQTT: {config.BROKER_HOST}")
338
-
339
- bridge.set_loop(asyncio.get_running_loop())
340
- bridge.connect()
341
- logger.info("🚀 API Server Started")
342
  yield
343
- # Shutdown
344
- bridge.disconnect()
345
- logger.info("👋 API Server Stopped")
346
 
347
- app = FastAPI(title="MQTT OpenAI Bridge", version="2.0", lifespan=lifespan)
348
 
349
  app.add_middleware(
350
  CORSMiddleware,
@@ -352,160 +241,93 @@ app.add_middleware(
352
  allow_methods=["*"],
353
  allow_headers=["*"],
354
  )
355
- @app.get("/")
356
- async def root():
357
- return {"message": "AI MQTT Bridge" }
358
-
359
- # --- Exception Handler for OpenAI format ---
360
- @app.exception_handler(HTTPException)
361
- async def openai_exception_handler(request: Request, exc: HTTPException):
362
- return JSONResponse(
363
- status_code=exc.status_code,
364
- content={
365
- "error": {
366
- "message": exc.detail,
367
- "type": "api_error",
368
- "param": None,
369
- "code": exc.status_code
370
- }
371
- }
372
- )
373
 
374
- # ============================================================
375
- # ENDPOINTS
376
- # ============================================================
377
 
378
- @app.get("/v1/models")
379
- async def list_models():
380
- return {"object": "list", "data": bridge.get_active_models()}
381
 
382
- @app.get("/health")
383
- async def health():
384
- return {
385
- "status": "ok",
386
- "mqtt": bridge.client.is_connected(),
387
- "workers": len(bridge.sessions)
388
- }
389
 
390
- @app.post("/v1/chat/completions")
391
- async def chat_completions(req: ChatCompletionRequest, request: Request):
 
 
 
 
 
392
 
393
- # Generate authoritative ID for this interaction
394
- chat_id = f"chatcmpl-{uuid.uuid4().hex}"
395
- created_ts = int(time.time())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
 
397
- # --- Streaming Response ---
398
  if req.stream:
399
- async def generator():
400
  try:
401
- async for raw_chunk in bridge.send_chat_request(req):
402
  if await request.is_disconnected():
403
- logger.info("Client disconnected, stopping stream")
404
  break
405
 
406
- # Convert raw worker JSON to Pydantic strict format
407
- # Worker sends: { choices: [{ delta: { content: "x" } }] }
408
- # We need to ensure fields exist
409
-
410
- worker_choices = raw_chunk.get("choices", [])
411
- delta = {}
412
- finish_reason = None
413
 
414
- if worker_choices:
415
- delta = worker_choices[0].get("delta", {}) or worker_choices[0].get("message", {})
416
- finish_reason = worker_choices[0].get("finish_reason")
417
-
418
- chunk_resp = ChatCompletionChunk(
419
- id=chat_id,
420
- created=created_ts,
421
- model=req.model,
422
- choices=[
423
- ChoiceChunk(
424
- index=0,
425
- delta=ChoiceDelta(
426
- content=delta.get("content"),
427
- reasoning_content=delta.get("reasoning_content"),
428
- role=delta.get("role") if delta.get("role") else None
429
- ),
430
- finish_reason=finish_reason
431
- )
432
- ]
433
  )
434
-
435
- yield f"data: {chunk_resp.model_dump_json(exclude_none=True)}\n\n"
436
-
437
- if finish_reason:
438
- yield "data: [DONE]\n\n"
439
- break
440
-
441
- except HTTPException as e:
442
- # If streaming started, we can't easily change status code,
443
- # but we can send an error object in the stream
444
- err_payload = json.dumps({"error": {"message": e.detail, "code": e.status_code}})
445
- yield f"data: {err_payload}\n\n"
446
- except Exception as e:
447
- logger.error(f"Stream error: {e}")
448
 
449
- return StreamingResponse(generator(), media_type="text/event-stream")
450
-
451
- # --- Non-Streaming Response (Buffered) ---
 
 
 
 
 
452
  else:
453
  full_content = ""
454
  full_reasoning = ""
455
- finish_reason = "stop"
456
-
457
- async for raw_chunk in bridge.send_chat_request(req):
458
- if await request.is_disconnected():
459
- raise HTTPException(499, "Client Closed Request")
460
-
461
- # Handle both "stream-like" chunks and "full" responses from worker
462
- choices = raw_chunk.get("choices", [])
463
  if not choices: continue
464
-
465
- delta = choices[0].get("delta", {}) or choices[0].get("message", {})
466
-
467
- if "content" in delta and delta["content"]:
468
- full_content += delta["content"]
469
- if "reasoning_content" in delta and delta["reasoning_content"]:
470
- full_reasoning += delta["reasoning_content"]
471
-
472
- if choices[0].get("finish_reason"):
473
- finish_reason = choices[0].get("finish_reason")
474
-
475
- return ChatCompletionResponse(
476
- id=chat_id,
477
- created=created_ts,
478
- model=req.model,
479
- choices=[
480
- Choice(
481
- index=0,
482
- message=ChoiceMessage(
483
- role="assistant",
484
- content=full_content,
485
- reasoning_content=full_reasoning if full_reasoning else None
486
- ),
487
- finish_reason=finish_reason
488
- )
489
- ],
490
- usage=UsageInfo(
491
- completion_tokens=len(full_content) // 4, # Rough estimate
492
- prompt_tokens=len(str(req.messages)) // 4,
493
- total_tokens=0
494
- )
495
- )
496
 
497
- # ============================================================
498
- # MAIN
499
- # ============================================================
500
  if __name__ == "__main__":
501
- logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
502
-
503
- print(f"🚀 Starting API Server on {config.API_HOST}:{config.API_PORT}")
504
- print(f"🔗 Bridging to MQTT: {config.BROKER_HOST}")
505
-
506
- uvicorn.run(
507
- app,
508
- host=config.API_HOST,
509
- port=config.API_PORT,
510
- log_level="info"
511
- )
 
1
  #!/usr/bin/env python3
2
  """
3
+ api_server.py — High-Performance OpenAI-Compatible MQTT Proxy (Multi-Worker Edition)
4
 
5
  Features:
6
+ - Full OpenAI Chat Completion API support.
7
+ - Multi-Worker Discovery: Lists every active browser tab as a unique model.
8
+ - Intelligent Routing: Routes requests to specific workers or load-balances across ready ones.
9
+ - Enhanced Session Isolation: Handles per-tab worker sessions (Zen v9.5).
10
  """
11
 
12
  import json
13
  import time
14
  import uuid
15
  import asyncio
 
16
  import logging
17
  import os
 
18
  from typing import Optional, List, Dict, Any, Union, Literal, AsyncGenerator
19
  from contextlib import asynccontextmanager
20
 
21
+ from fastapi.responses import HTMLResponse
22
+ from fastapi import FastAPI, HTTPException, Request
23
  from fastapi.responses import StreamingResponse, JSONResponse
24
  from fastapi.middleware.cors import CORSMiddleware
25
+ from pydantic import BaseModel
26
  import paho.mqtt.client as mqtt
27
  from paho.mqtt.client import CallbackAPIVersion
28
  import uvicorn
 
34
  BROKER_HOST = os.getenv("MQTT_BROKER_HOST", "nxdev-org-mqtt-broker.hf.space")
35
  BROKER_PORT = int(os.getenv("MQTT_BROKER_PORT", "443"))
36
  USE_TLS = os.getenv("MQTT_USE_TLS", "true").lower() in ("1", "true", "yes")
 
 
 
 
 
37
  WS_PATH = os.getenv("MQTT_WS_PATH", "/mqtt")
38
 
39
+ API_HOST = "0.0.0.0"
40
+ API_PORT = 8001
41
 
42
+ TIMEOUT_SEC = 120.0
43
+ SESSION_EXPIRY = 30.0 # Workers must heartbeat every 2s, 30s is generous
44
 
45
  config = Config()
46
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
47
+ logger = logging.getLogger("openai-proxy")
48
 
49
  # ============================================================
50
+ # MODELS
51
  # ============================================================
52
 
53
class ChatMessage(BaseModel):
    """A single chat turn in OpenAI message format."""
    role: str     # e.g. "system" | "user" | "assistant" — not validated here
    content: str  # plain-text message body
 
56
 
57
class ChatCompletionRequest(BaseModel):
    """Incoming /v1/chat/completions payload (minimal OpenAI subset)."""
    model: str                 # worker SID, "ModelName:SID", friendly model name, or fallback
    messages: List[ChatMessage]
    stream: bool = False       # SSE streaming when true
    temperature: float = 1.0   # forwarded to the worker as-is
 
 
 
 
 
 
 
 
 
62
 
63
class ChoiceDelta(BaseModel):
    """Incremental content for one streamed choice."""
    content: Optional[str] = None            # visible answer text
    reasoning_content: Optional[str] = None  # thinking/reasoning extension field
66
 
 
 
 
 
 
67
class ChoiceChunk(BaseModel):
    """One choice entry inside a streaming chunk."""
    delta: ChoiceDelta
    finish_reason: Optional[str] = None  # set only on the final chunk
    index: int = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
class ChatCompletionChunk(BaseModel):
    """OpenAI "chat.completion.chunk" payload sent as one SSE data event."""
    id: str
    object: str = "chat.completion.chunk"
    created: int  # unix timestamp (seconds)
    model: str
    choices: List[ChoiceChunk]
78
 
79
  # ============================================================
80
+ # MQTT BRIDGE ENGINE
81
  # ============================================================
82
 
83
class OpenAIProxyEngine:
    """Bridges OpenAI-style HTTP requests onto MQTT worker sessions.

    Workers announce themselves on "arena-ai/global/heartbeat" and answer
    requests on "arena-ai/{sid}/response". Paho callbacks run on the MQTT
    network thread; response payloads are handed to asyncio queues via
    call_soon_threadsafe.
    """

    def __init__(self):
        self.client_id = f"proxy-{uuid.uuid4().hex[:8]}"
        # sid -> {model, status, last_seen, host}; written from the MQTT thread.
        self.workers: Dict[str, Dict] = {}
        # In-flight request id -> queue of raw response payloads.
        self._queues: Dict[str, asyncio.Queue] = {}
        self._loop: Optional[asyncio.AbstractEventLoop] = None

        # Paho MQTT setup (WebSocket transport, optional TLS).
        self.mqtt = mqtt.Client(
            callback_api_version=CallbackAPIVersion.VERSION2,
            client_id=self.client_id,
            transport="websockets",
        )
        if config.USE_TLS:
            self.mqtt.tls_set()
        self.mqtt.ws_set_options(path=config.WS_PATH, headers={"Sec-WebSocket-Protocol": "mqtt"})

        self.mqtt.on_connect = self._on_connect
        self.mqtt.on_message = self._on_message

    def set_loop(self, loop):
        """Record the asyncio loop used to hand MQTT messages to coroutines."""
        self._loop = loop

    def _on_connect(self, client, userdata, flags, rc, props=None):
        """MQTT-thread callback: subscribe and trigger worker discovery."""
        if rc == 0:
            logger.info("✅ Proxy connected to MQTT broker")
            client.subscribe("arena-ai/+/response")
            client.subscribe("arena-ai/global/heartbeat")
            # Ask all workers to announce themselves immediately.
            client.publish("arena-ai/global/discovery", "ping")
        else:
            logger.error(f"❌ MQTT Connection failed: {rc}")

    def _on_message(self, client, userdata, msg):
        """MQTT-thread callback: track worker heartbeats and route responses."""
        try:
            topic = msg.topic
            payload = json.loads(msg.payload.decode())

            if topic == "arena-ai/global/heartbeat":
                sid = payload.get("id")
                if sid:
                    self.workers[sid] = {
                        "last_seen": time.time(),
                        "model": payload.get("model", "AI-Worker"),
                        "status": payload.get("status", "ready"),
                        "host": payload.get("host", "unknown"),
                    }
                return

            if topic.endswith("/response"):
                rid = payload.get("id")
                if rid in self._queues and self._loop:
                    # Thread-safe hand-off into the asyncio world.
                    self._loop.call_soon_threadsafe(self._queues[rid].put_nowait, payload)
        except Exception as e:
            logger.error(f"Error processing MQTT message: {e}")

    def get_active_workers(self):
        """Return fresh workers (heartbeat within SESSION_EXPIRY); prune the rest."""
        now = time.time()
        active = {}
        for sid, info in list(self.workers.items()):
            if now - info["last_seen"] < config.SESSION_EXPIRY:
                active[sid] = info
            else:
                del self.workers[sid]
        return active

    def _pick_worker(self, model: str, active: Dict[str, Dict]) -> str:
        """Resolve a requested model name to a live worker session id.

        Resolution order: exact SID -> "ModelName:SID" suffix (as advertised
        by /v1/models) -> friendly model name among ready workers -> any
        ready worker. Raises HTTPException(503) if nothing is ready.
        """
        # 1. Exact SID match.
        if model in active:
            return model
        # 2. "Model:SID" format match.
        if ":" in model:
            potential_sid = model.rsplit(":", 1)[-1]
            if potential_sid in active:
                return potential_sid
        # 3. Friendly model name among ready workers.
        for sid, info in active.items():
            if info["model"] == model and info["status"] == "ready":
                return sid
        # 4. Last resort: first ready worker.
        for sid, info in active.items():
            if info["status"] == "ready":
                return sid
        logger.error("❌ No active Zen workers available")
        raise HTTPException(status_code=503, detail="No active Zen Bridge workers found")

    async def chat(self, req: ChatCompletionRequest) -> AsyncGenerator[Dict, None]:
        """Route one chat request to a worker and yield its raw response payloads.

        Yields the worker's chunks untouched; stops after the final chunk
        (one carrying a finish_reason, or a full "chat.completion" object).
        Raises HTTPException(503) when no ready worker exists and
        asyncio.TimeoutError when the worker stops responding.
        """
        active = self.get_active_workers()
        target_sid = self._pick_worker(req.model, active)

        req_id = f"req-{uuid.uuid4().hex[:12]}"
        q = asyncio.Queue()
        self._queues[req_id] = q

        mqtt_payload = {
            "id": req_id,
            "messages": [m.model_dump() for m in req.messages],
            "stream": req.stream,
            "temperature": req.temperature
        }

        logger.info(f"📤 [OpenAI] Start {req_id} -> Worker {target_sid} ({active[target_sid]['model']})")

        try:
            self.mqtt.publish(f"arena-ai/{target_sid}/request", json.dumps(mqtt_payload), qos=1)

            start = time.time()
            chunk_count = 0
            while True:
                # BUGFIX: previously q.get() was awaited without a timeout, so a
                # dead worker hung the request forever (the elapsed-time check
                # above the await could never fire). Bound each wait by the
                # remaining time budget instead.
                remaining = config.TIMEOUT_SEC - (time.time() - start)
                if remaining <= 0:
                    logger.warning(f"⏰ {req_id} timed out")
                    raise asyncio.TimeoutError()
                try:
                    chunk = await asyncio.wait_for(q.get(), timeout=remaining)
                except asyncio.TimeoutError:
                    logger.warning(f"⏰ {req_id} timed out")
                    raise

                chunk_count += 1
                yield chunk

                choices = chunk.get("choices", [])
                is_done = bool(choices and choices[0].get("finish_reason")) \
                    or chunk.get("object") == "chat.completion"

                if is_done:
                    # Drain anything that raced in behind the final chunk.
                    while not q.empty():
                        yield q.get_nowait()
                        chunk_count += 1
                    logger.info(f"✅ [OpenAI] End {req_id} ({chunk_count} chunks)")
                    break
        finally:
            # Always unregister the queue so _queues cannot leak.
            if req_id in self._queues:
                del self._queues[req_id]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
  # ============================================================
222
+ # API SERVER
223
  # ============================================================
224
 
225
+ engine = OpenAIProxyEngine()
226
 
227
  @asynccontextmanager
228
  async def lifespan(app: FastAPI):
229
+ engine.set_loop(asyncio.get_running_loop())
230
+ engine.mqtt.connect(config.BROKER_HOST, config.BROKER_PORT)
231
+ engine.mqtt.loop_start()
 
 
 
 
232
  yield
233
+ engine.mqtt.loop_stop()
234
+ engine.mqtt.disconnect()
 
235
 
236
+ app = FastAPI(title="Zen OpenAI Proxy", lifespan=lifespan)
237
 
238
  app.add_middleware(
239
  CORSMiddleware,
 
241
  allow_methods=["*"],
242
  allow_headers=["*"],
243
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
 
 
 
245
 
246
+ @app.get("/", response_class=HTMLResponse)
247
+ async def index():
248
+ return "High-Performance OpenAI-Compatible MQTT Proxy (Multi-Worker Edition)"
249
 
 
 
 
 
 
 
 
250
 
251
+ @app.get("/v1/models")
252
+ async def models():
253
+ active = engine.get_active_workers()
254
+ data = []
255
+
256
+ # Add a generic "auto" model
257
+ data.append({"id": "auto", "object": "model", "owned_by": "zen-bridge"})
258
 
259
+ for sid, info in active.items():
260
+ # Represent each session as a model: "ModelName:SID"
261
+ model_id = f"{info['model']}:{sid}"
262
+ data.append({
263
+ "id": model_id,
264
+ "object": "model",
265
+ "owned_by": "zen-bridge",
266
+ "meta": {
267
+ "sid": sid,
268
+ "status": info["status"],
269
+ "host": info["host"]
270
+ }
271
+ })
272
+ return {"object": "list", "data": data}
273
+
274
+ @app.post("/v1/chat/completions")
275
+ async def chat(req: ChatCompletionRequest, request: Request):
276
+ chat_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
277
+ created = int(time.time())
278
 
 
279
  if req.stream:
280
+ async def stream_gen():
281
  try:
282
+ async for chunk in engine.chat(req):
283
  if await request.is_disconnected():
 
284
  break
285
 
286
+ choices = chunk.get("choices", [])
287
+ if not choices: continue
288
+ delta_data = choices[0].get("delta", {}) or choices[0].get("message", {})
 
 
 
 
289
 
290
+ resp = ChatCompletionChunk(
291
+ id=chat_id, created=created, model=req.model,
292
+ choices=[ChoiceChunk(
293
+ delta=ChoiceDelta(
294
+ content=delta_data.get("content"),
295
+ reasoning_content=delta_data.get("reasoning_content")
296
+ ),
297
+ finish_reason=choices[0].get("finish_reason")
298
+ )]
 
 
 
 
 
 
 
 
 
 
299
  )
300
+ yield f"data: {resp.model_dump_json(exclude_none=True)}\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
+ if not await request.is_disconnected():
303
+ yield "data: [DONE]\n\n"
304
+ except Exception as e:
305
+ logger.error(f"Stream Error: {e}")
306
+ yield f"data: {json.dumps({'error': str(e)})}\n\n"
307
+
308
+ return StreamingResponse(stream_gen(), media_type="text/event-stream")
309
+
310
  else:
311
  full_content = ""
312
  full_reasoning = ""
313
+ async for chunk in engine.chat(req):
314
+ choices = chunk.get("choices", [])
 
 
 
 
 
 
315
  if not choices: continue
316
+ delta_data = choices[0].get("delta", {}) or choices[0].get("message", {})
317
+ full_content += delta_data.get("content", "") or ""
318
+ full_reasoning += delta_data.get("reasoning_content", "") or ""
319
+
320
+ return {
321
+ "id": chat_id, "object": "chat.completion", "created": created, "model": req.model,
322
+ "choices": [{
323
+ "message": {
324
+ "role": "assistant",
325
+ "content": full_content,
326
+ "reasoning_content": full_reasoning if full_reasoning else None
327
+ },
328
+ "finish_reason": "stop", "index": 0
329
+ }]
330
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
 
 
 
 
332
if __name__ == "__main__":
    # Run the proxy directly under uvicorn using the host/port from Config.
    uvicorn.run(app, host=config.API_HOST, port=config.API_PORT)