#!/usr/bin/env python3 """ api_server.py — OpenAI-Compatible MQTT Proxy v3.0 • Modern chat UI with markdown, code highlighting, thinking blocks • Admin debug console with live connection / worker / stats monitoring • Pressure-aware load balancing across browser-tab workers • Robust SSE streaming with proper chunk buffering """ import json, time, uuid, asyncio, logging, os from collections import deque from typing import Optional, List, Dict, AsyncGenerator from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException, Request from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import paho.mqtt.client as mqtt from paho.mqtt.client import CallbackAPIVersion import uvicorn # ════════════════════════════════════════════════════════════ # CONFIGURATION # ════════════════════════════════════════════════════════════ class Config: BROKER_HOST = os.getenv("MQTT_BROKER_HOST", os.getenv("BROKER_HOST", "127.0.0.1")) BROKER_PORT = int(os.getenv("MQTT_BROKER_PORT", os.getenv("BROKER_PORT", "1883"))) USE_TLS = os.getenv("MQTT_USE_TLS", "false").lower() in ("1", "true", "yes") WS_PATH = os.getenv("MQTT_WS_PATH", "/mqtt") WS_TRANSPORT = os.getenv("MQTT_TRANSPORT", "websockets" if int(os.getenv("MQTT_BROKER_PORT", os.getenv("BROKER_PORT", "1883"))) in [80, 443, 7860] else "tcp") API_HOST = "0.0.0.0" API_PORT = 8001 TIMEOUT_SEC = 180.0 SESSION_EXPIRY = 45.0 DEBUG_MODE = os.getenv("DEBUG_MODE", "false").lower() in ("1", "true", "yes") config = Config() logging.basicConfig( level=logging.DEBUG if config.DEBUG_MODE else logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", ) logger = logging.getLogger("zen-proxy") # ════════════════════════════════════════════════════════════ # PYDANTIC MODELS # ════════════════════════════════════════════════════════════ class ChatMessage(BaseModel): role: str content: str class ChatCompletionRequest(BaseModel): model: str messages: List[ChatMessage] stream: bool = False temperature: float = 1.0 class ChoiceDelta(BaseModel): content: Optional[str] = None reasoning_content: Optional[str] = None class ChoiceChunk(BaseModel): delta: ChoiceDelta finish_reason: Optional[str] = None index: int = 0 class ChatCompletionChunk(BaseModel): id: str object: str = "chat.completion.chunk" created: int model: str choices: List[ChoiceChunk] # ════════════════════════════════════════════════════════════ # MQTT PROXY ENGINE # ════════════════════════════════════════════════════════════ class ProxyEngine: def __init__(self): self.client_id = f"proxy-{uuid.uuid4().hex[:8]}" self.workers: Dict[str, Dict] = {} self._queues: Dict[str, asyncio.Queue] = {} self._loop: Optional[asyncio.AbstractEventLoop] = None self.connected = False self.activity_log: deque = deque(maxlen=200) self.stats = dict( start_time=time.time(), total_requests=0, active_streams=0, completed=0, failed=0, total_chunks=0, heartbeats_rx=0, ) self.mqtt = mqtt.Client( callback_api_version=CallbackAPIVersion.VERSION2, client_id=self.client_id, transport=config.WS_TRANSPORT, ) if config.USE_TLS: self.mqtt.tls_set() if config.WS_TRANSPORT == "websockets": self.mqtt.ws_set_options( path=config.WS_PATH, headers={"Sec-WebSocket-Protocol": "mqtt"}, ) self.mqtt.on_connect = self._on_connect self.mqtt.on_message = self._on_message self.mqtt.on_disconnect = self._on_disconnect # ── MQTT callbacks (run in paho thread) ────────────────── def _on_connect(self, client, userdata, flags, rc, props=None): if rc == 0: self.connected = True logger.info("✅ MQTT connected (%s:%s %s)", config.BROKER_HOST, config.BROKER_PORT, config.WS_TRANSPORT) client.subscribe("arena-ai/+/response") client.subscribe("arena-ai/global/heartbeat") self._log("system", "mqtt/connect", "Connected to broker") else: logger.error("❌ MQTT connect failed rc=%s", rc) def _on_disconnect(self, client, userdata, flags, rc, props=None): self.connected = False logger.warning("⚠️ MQTT disconnected rc=%s — will auto-reconnect", rc) self._log("system", "mqtt/disconnect", f"Disconnected rc={rc}") def _on_message(self, client, userdata, msg): try: topic = msg.topic payload = json.loads(msg.payload.decode()) if topic == "arena-ai/global/heartbeat": sid = payload.get("id") if sid: self.workers[sid] = dict( last_seen=time.time(), model=payload.get("model", "AI-Worker"), status=payload.get("status", "ready"), pressure=payload.get("pressure", 0), ) self.stats["heartbeats_rx"] += 1 self._log("heartbeat", topic, f"{sid} p={payload.get('pressure',0)}") return if topic.endswith("/response"): rid = payload.get("id") if rid and rid in self._queues and self._loop: self.stats["total_chunks"] += 1 self._loop.call_soon_threadsafe(self._queues[rid].put_nowait, payload) self._log("response", topic, f"{rid}") except Exception as exc: logger.error("Message parse error: %s", exc) # ── helpers ────────────────────────────────────────────── def _log(self, kind: str, topic: str, summary: str): self.activity_log.append(dict( ts=time.time(), time=time.strftime("%H:%M:%S"), kind=kind, topic=topic, summary=summary, )) def set_loop(self, loop): self._loop = loop def get_active_workers(self) -> Dict[str, Dict]: now = time.time() expired = [s for s, i in self.workers.items() if now - i["last_seen"] >= config.SESSION_EXPIRY] for s in expired: del self.workers[s] return dict(self.workers) # ── core chat generator ────────────────────────────────── async def chat(self, req: ChatCompletionRequest) -> AsyncGenerator[Dict, None]: self.stats["total_requests"] += 1 self.stats["active_streams"] += 1 active = self.get_active_workers() target = None # direct model:sid routing if ":" in req.model: sid = req.model.rsplit(":", 1)[-1] if sid in active: target = sid # least-pressure routing if not target: cands = [(s, i) for s, i in active.items() if (req.model in i["model"] or req.model == "auto") and i["status"] == "ready"] if cands: cands.sort(key=lambda x: x[1]["pressure"]) target = cands[0][0] if not target: self.stats["active_streams"] -= 1 self.stats["failed"] += 1 raise HTTPException(503, "No active workers. Open a Zen Bridge tab.") rid = f"req-{uuid.uuid4().hex[:12]}" q: asyncio.Queue = asyncio.Queue() self._queues[rid] = q mqtt_payload = dict( id=rid, messages=[m.model_dump() for m in req.messages], stream=req.stream, temperature=req.temperature, ) logger.info("📤 %s → %s (%s)", rid, active[target]["model"], target) self._log("request", f"arena-ai/{target}/request", rid) try: self.mqtt.publish( f"arena-ai/{target}/request", json.dumps(mqtt_payload), qos=1 ) deadline = time.time() + config.TIMEOUT_SEC while True: remaining = deadline - time.time() if remaining <= 0: self.stats["failed"] += 1 raise HTTPException(504, "Worker response timeout") try: chunk = await asyncio.wait_for(q.get(), timeout=min(remaining, 30)) except asyncio.TimeoutError: continue yield chunk choices = chunk.get("choices", []) if (choices and choices[0].get("finish_reason")) \ or chunk.get("object") == "chat.completion": self.stats["completed"] += 1 break except HTTPException: raise except Exception as exc: self.stats["failed"] += 1 logger.error("Chat error: %s", exc) raise HTTPException(502, str(exc)) finally: self.stats["active_streams"] -= 1 self._queues.pop(rid, None) engine = ProxyEngine() # ════════════════════════════════════════════════════════════ # HTML — Landing Page # ════════════════════════════════════════════════════════════ LANDING_HTML = """
OpenAI-Compatible MQTT Bridge
Select a model and start a conversation