Spaces:
Sleeping
Sleeping
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import sys | |
| import traceback | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Dict | |
| from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect, HTTPException, Query | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import JSONResponse, EventSourceResponse | |
| from fastapi.exceptions import RequestValidationError | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from slowapi import Limiter, _rate_limit_exceeded_handler | |
| from slowapi.util import get_remote_address | |
| from slowapi.errors import RateLimitExceeded | |
| import uvicorn | |
| # ============================================================================= | |
| # Centralized JSON Logging | |
| # ============================================================================= | |
class JSONFormatter(logging.Formatter):
    """Format log records as single-line JSON objects.

    Every entry carries ``timestamp``, ``level``, ``logger`` and ``message``.
    Exception details are added under ``exception`` when the record has
    ``exc_info``. Fields passed through ``logger.info(..., extra={...})`` are
    merged in as top-level keys, as is an explicit ``extra_fields`` mapping
    attached to the record.
    """

    # Attribute names that every LogRecord carries. Anything else found on a
    # record was supplied by the caller through ``extra=``.
    _RESERVED_ATTRS = frozenset(
        logging.LogRecord("", 0, "", 0, "", None, None).__dict__
    ) | {"message", "asctime", "taskName", "extra_fields"}

    def format(self, record: logging.LogRecord) -> str:
        """Return *record* serialized as a JSON string."""
        log_entry: Dict[str, Any] = {
            "timestamp": datetime.utcnow().isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }

        # Add exception info if present
        if record.exc_info:
            exc_type, exc_value = record.exc_info[0], record.exc_info[1]
            log_entry["exception"] = {
                "type": exc_type.__name__ if exc_type else None,
                "message": str(exc_value) if exc_value is not None else None,
                "traceback": self.formatException(record.exc_info),
            }

        # ``extra={...}`` lands as plain attributes on the record, so collect
        # every non-standard attribute. Core keys are never overwritten.
        for key, value in record.__dict__.items():
            if key not in self._RESERVED_ATTRS and key not in log_entry:
                log_entry[key] = value

        # Explicit ``extra_fields`` mapping keeps its original semantics
        # (it may override other keys).
        extra_fields = getattr(record, "extra_fields", None)
        if extra_fields:
            log_entry.update(extra_fields)

        # default=str: a log call must never fail because an extra value
        # (exception object, Path, datetime, ...) is not JSON-native.
        return json.dumps(log_entry, default=str)
def setup_logging() -> None:
    """Configure centralized JSON logging for the root and uvicorn loggers.

    Safe to call more than once: the root JSON console handler is attached
    only if one is not already present, so repeated calls do not duplicate
    every log line.
    """
    # Root logger configuration
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # Console handler with JSON formatter (added once)
    has_json_handler = any(
        isinstance(handler.formatter, JSONFormatter)
        for handler in root_logger.handlers
    )
    if not has_json_handler:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(JSONFormatter())
        root_logger.addHandler(console_handler)

    # Uvicorn access logger: replace its handlers with a dedicated JSON one
    # and stop propagation so access lines are not also emitted by the root.
    access_logger = logging.getLogger("uvicorn.access")
    access_logger.handlers.clear()
    access_logger.propagate = False
    access_handler = logging.StreamHandler(sys.stdout)
    access_handler.setFormatter(JSONFormatter())
    access_logger.addHandler(access_handler)

    # Uvicorn error logger: drop its own handlers and let records bubble up
    # to the root JSON handler.
    error_logger = logging.getLogger("uvicorn.error")
    error_logger.handlers.clear()
    error_logger.propagate = True
# Logging is configured as an import side effect so that everything logged
# while the rest of this module loads is already JSON formatted.
setup_logging()
logger = logging.getLogger(__name__)

# =============================================================================
# FastAPI Application with Global Exception Handler
# =============================================================================
# Rate limiter keyed by the caller's remote address.
# NOTE(review): the original comment stated "60 requests per minute per IP",
# but no default limit is configured here -- confirm where limits are set.
limiter = Limiter(key_func=get_remote_address)

app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# =============================================================================
# CORS Middleware with WebSocket Support
# =============================================================================
# NOTE(review): wildcard origins combined with allow_credentials=True is
# discouraged by the FastAPI CORS documentation -- confirm this is intended.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
| # ============================================================================= | |
| # WebSocket Connection Manager | |
| # ============================================================================= | |
class ConnectionManager:
    """Track WebSocket connections by client id and broadcast to all of them."""

    def __init__(self) -> None:
        # client_id -> accepted WebSocket
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, client_id: str, websocket: WebSocket) -> bool:
        """Accept and register *websocket* under *client_id*.

        Returns:
            True if the connection was accepted and registered, False if
            *client_id* is already connected (the socket is NOT accepted in
            that case; the caller decides how to close it).
        """
        if client_id in self.active_connections:
            logger.warning(
                f"Duplicate connection attempt for client_id: {client_id}",
                extra={"client_id": client_id, "action": "reject_duplicate"},
            )
            return False

        await websocket.accept()
        self.active_connections[client_id] = websocket
        logger.info(
            f"Client connected: {client_id}",
            extra={
                "client_id": client_id,
                "action": "connect",
                "active_count": len(self.active_connections),
            },
        )
        return True

    def disconnect(self, client_id: str) -> None:
        """Forget *client_id*; a no-op if it is not registered."""
        if self.active_connections.pop(client_id, None) is None:
            return
        logger.info(
            f"Client disconnected: {client_id}",
            extra={
                "client_id": client_id,
                "action": "disconnect",
                "active_count": len(self.active_connections),
            },
        )

    async def broadcast(self, message: Any) -> None:
        """Send *message* as JSON to every connection, pruning dead ones."""
        if not self.active_connections:
            return

        disconnected_clients = []
        # Iterate over a snapshot: each await yields to the event loop, and a
        # concurrent connect()/disconnect() would otherwise raise
        # "dictionary changed size during iteration".
        for client_id, websocket in list(self.active_connections.items()):
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning(
                    f"Failed to send to client {client_id}: {e}",
                    extra={"client_id": client_id, "error": str(e)},
                )
                disconnected_clients.append(client_id)

        # Clean up clients whose send failed
        for client_id in disconnected_clients:
            self.disconnect(client_id)
| # ============================================================================= | |
| # CollaborationHub - Room-based WebSocket Message Broker | |
| # ============================================================================= | |
class CollaborationHub:
    """Room-based WebSocket message broker.

    Features:
    - Fixed set of rooms: adam, eve, system, dialogue
    - Each client belongs to exactly one room at a time
    - Send to one room, to every room, or to a single client
    - Clients whose send fails are removed automatically
    """

    # Available room types
    ROOMS = {"adam", "eve", "system", "dialogue"}

    def __init__(self) -> None:
        # Room -> {client_id -> WebSocket}
        self.rooms: Dict[str, Dict[str, WebSocket]] = {room: {} for room in self.ROOMS}
        # Client tracking: client_id -> room
        self.client_room: Dict[str, str] = {}
        # Client info: client_id -> metadata
        self.client_info: Dict[str, Dict[str, Any]] = {}

    async def connect(self, client_id: str, websocket: WebSocket, room: str = "system") -> bool:
        """Accept *websocket* and register the client in *room*.

        Args:
            client_id: Unique identifier for the client
            websocket: WebSocket connection object
            room: Room to join (adam, eve, system, dialogue)

        Returns:
            True if connected, False if the room is invalid or the client is
            already in a room (the socket is not accepted in those cases).

        Raises:
            Whatever ``websocket.send_json`` raises if the join notification
            cannot be delivered; the registration is rolled back first.
        """
        if room not in self.ROOMS:
            logger.warning(
                f"Invalid room: {room}",
                extra={"client_id": client_id, "room": room, "action": "reject_invalid_room"},
            )
            return False

        # Reject if client already connected to any room
        if client_id in self.client_room:
            existing_room = self.client_room[client_id]
            logger.warning(
                f"Client {client_id} already in room {existing_room}",
                extra={
                    "client_id": client_id,
                    "existing_room": existing_room,
                    "action": "reject_duplicate",
                },
            )
            return False

        await websocket.accept()
        self.rooms[room][client_id] = websocket
        self.client_room[client_id] = room
        self.client_info[client_id] = {
            "room": room,
            "connected_at": datetime.utcnow().isoformat(),
        }
        logger.info(
            f"Client {client_id} joined room: {room}",
            extra={
                "client_id": client_id,
                "room": room,
                "action": "join_room",
                "room_members": len(self.rooms[room]),
            },
        )

        # Notify client of successful join
        try:
            await websocket.send_json({
                "type": "room_joined",
                "room": room,
                "client_id": client_id,
                "timestamp": datetime.utcnow().isoformat(),
            })
        except Exception:
            # The caller never reaches its cleanup path when connect() raises,
            # so undo the registration here instead of leaking a dead socket.
            self.disconnect(client_id)
            raise
        return True

    def disconnect(self, client_id: str) -> None:
        """Remove *client_id* from its room; a no-op for unknown clients."""
        room = self.client_room.pop(client_id, None)
        if room is None:
            return
        self.rooms[room].pop(client_id, None)
        self.client_info.pop(client_id, None)
        logger.info(
            f"Client {client_id} left room: {room}",
            extra={
                "client_id": client_id,
                "room": room,
                "action": "leave_room",
                "room_members": len(self.rooms[room]),
            },
        )

    async def send_to_room(self, room: str, message: Any) -> int:
        """Send *message* to all clients in *room*.

        Args:
            room: Target room name
            message: Message payload (will be JSON serialized)

        Returns:
            Number of clients the message was sent to (0 for an invalid or
            empty room).
        """
        if room not in self.ROOMS:
            logger.warning(f"Cannot send to invalid room: {room}")
            return 0
        if not self.rooms[room]:
            return 0

        disconnected = []
        sent_count = 0
        # Snapshot the members: awaiting send_json lets other tasks join or
        # leave the room, which would break iteration over the live dict.
        for client_id, websocket in list(self.rooms[room].items()):
            try:
                await websocket.send_json(message)
                sent_count += 1
            except Exception as e:
                logger.warning(
                    f"Failed to send to {client_id} in room {room}: {e}",
                    extra={"client_id": client_id, "room": room, "error": str(e)},
                )
                disconnected.append(client_id)

        # Clean up disconnected clients
        for client_id in disconnected:
            self.disconnect(client_id)
        return sent_count

    async def send_to_client(self, client_id: str, message: Any) -> bool:
        """Send *message* to one client.

        Args:
            client_id: Target client ID
            message: Message payload (will be JSON serialized)

        Returns:
            True if sent, False if the client is unknown or the send failed
            (a failed client is disconnected).
        """
        room = self.client_room.get(client_id)
        if room is None:
            return False
        websocket = self.rooms[room].get(client_id)
        if websocket is None:
            return False
        try:
            await websocket.send_json(message)
            return True
        except Exception as e:
            logger.warning(
                f"Failed to send to {client_id}: {e}",
                extra={"client_id": client_id, "error": str(e)},
            )
            self.disconnect(client_id)
            return False

    async def broadcast(self, message: Any, exclude_room: str = None) -> int:
        """Send *message* to every room, optionally skipping one.

        Args:
            message: Message payload (will be JSON serialized)
            exclude_room: Optional room to exclude from broadcast

        Returns:
            Total number of clients the message was sent to
        """
        total_sent = 0
        for room in self.ROOMS:
            if exclude_room and room == exclude_room:
                continue
            total_sent += await self.send_to_room(room, message)
        return total_sent

    async def simulate_agent_dialogue(self, speaker: str, listener: str, message: str) -> None:
        """Relay a dialogue line from one agent room to another.

        The message goes to the listener's room and a copy typed
        ``agent_dialogue_log`` goes to the ``system`` room. Calls with a
        speaker or listener other than ``adam``/``eve`` are logged and dropped.

        Args:
            speaker: Room name of the speaking agent (adam or eve)
            listener: Room name of the listening agent (adam or eve)
            message: The message content
        """
        agents = {"adam", "eve"}
        if speaker not in agents or listener not in agents:
            logger.warning(f"Invalid agent dialogue: {speaker} -> {listener}")
            return

        dialogue_message = {
            "type": "agent_dialogue",
            "speaker": speaker,
            "listener": listener,
            "message": message,
            "timestamp": datetime.utcnow().isoformat(),
        }
        # Send to listener's room
        sent_count = await self.send_to_room(listener, dialogue_message)
        # Also send to system room for logging
        await self.send_to_room("system", {
            **dialogue_message,
            "type": "agent_dialogue_log",
        })
        logger.info(
            f"Agent dialogue: {speaker} -> {listener}: {message[:50]}...",
            extra={
                "speaker": speaker,
                "listener": listener,
                "recipients": sent_count,
                "message_length": len(message),
            },
        )

    def get_room_status(self) -> Dict[str, Any]:
        """Return ``{room: {"members": [...], "count": n}}`` for every room."""
        return {
            room: {"members": list(clients), "count": len(clients)}
            for room, clients in self.rooms.items()
        }

    def get_client_info(self, client_id: str) -> Dict[str, Any]:
        """Return stored metadata for *client_id*, or ``{}`` if unknown."""
        return self.client_info.get(client_id, {})
# Process-wide singletons shared by the HTTP and WebSocket handlers.
manager: ConnectionManager = ConnectionManager()
hub: CollaborationHub = CollaborationHub()
async def global_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Log an unhandled exception and return a 500 JSON error response.

    Args:
        request: The request being processed when *exc* was raised.
        exc: The unhandled exception.

    Returns:
        A 500 response whose body is ``{"error": {"type", "message"}}``.
    """
    exc_type = type(exc).__name__
    exc_message = str(exc)
    # Build the traceback from the exception object itself rather than from
    # the ambient sys.exc_info(), so the handler also works when it is not
    # invoked from inside an ``except`` block.
    exc_info = (type(exc), exc, exc.__traceback__)
    exc_traceback = "".join(traceback.format_exception(*exc_info))

    # Log the full exception with stacktrace
    logger.error(
        f"Unhandled exception: {exc_type}: {exc_message}",
        extra={
            "exception_type": exc_type,
            "exception_message": exc_message,
            "traceback": exc_traceback,
            "path": request.url.path,
            "method": request.method,
        },
        exc_info=exc_info,
    )

    # Return 500 JSON response with error details
    return JSONResponse(
        status_code=500,
        content={
            "error": {
                "type": exc_type,
                "message": exc_message,
            }
        },
    )
async def validation_exception_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Log a request validation failure and return it as a 422 JSON response.

    Args:
        request: The request that failed validation.
        exc: The validation error raised by FastAPI.

    Returns:
        A 422 response whose body is
        ``{"error": {"type": "ValidationError", "details": [...]}}``.
    """
    # Local import keeps this block self-contained.
    from fastapi.encoders import jsonable_encoder

    raw_errors = exc.errors()
    # Error entries can carry values that are not JSON-native (for example
    # exception objects in ``ctx``); encode them before logging/serializing
    # so building the response cannot itself fail.
    details = jsonable_encoder(raw_errors)
    logger.warning(
        f"Validation error: {raw_errors}",
        extra={"validation_errors": details, "path": request.url.path},
    )
    return JSONResponse(
        status_code=422,
        content={"error": {"type": "ValidationError", "details": details}},
    )
# Record whether the optional ``brain_minimal`` module could be imported.
try:
    import brain_minimal
except ImportError:
    _brain_loaded = False
else:
    _brain_loaded = True

# Last known worker state; overwritten by the heartbeat handler.
_worker_state = dict(
    worker_pid=None,
    worker_mode=None,
    worker_state="inactive",
    current_state="idle",
    worker_active=False,
    stage="STARTUP",
    last_heartbeat=None,
)
| # ============================================================================= | |
| # CollaborationHub HTTP API Endpoints | |
| # ============================================================================= | |
async def hub_status(request: Request):
    """Report per-room membership and the total number of hub clients."""
    snapshot = {"timestamp": datetime.utcnow().isoformat()}
    snapshot["rooms"] = hub.get_room_status()
    snapshot["total_clients"] = len(hub.client_room)
    return snapshot
async def trigger_agent_dialogue(request: Request):
    """Trigger an agent dialogue simulation.

    Body: {"speaker": "adam", "listener": "eve", "message": "Hello!"}

    Returns:
        A confirmation dict on success; a 400 JSON response when the body is
        not a JSON object, ``message`` is missing or not a string, or
        speaker/listener is not ``adam``/``eve``; a 500 JSON response on any
        other failure.
    """
    try:
        try:
            payload = await request.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError both derive from
            # ValueError: a malformed body is a client error, not a 500.
            return JSONResponse(
                status_code=400,
                content={"error": "request body must be valid JSON"},
            )
        if not isinstance(payload, dict):
            return JSONResponse(
                status_code=400,
                content={"error": "request body must be a JSON object"},
            )

        speaker = payload.get("speaker", "adam")
        listener = payload.get("listener", "eve")
        message = payload.get("message", "")
        if not message:
            return JSONResponse(
                status_code=400,
                content={"error": "message is required"},
            )
        if not isinstance(message, str):
            return JSONResponse(
                status_code=400,
                content={"error": "message must be a string"},
            )
        # The hub silently drops dialogue between unknown agents; reject it
        # here so the caller is not told "sent" for a message that went nowhere.
        agents = {"adam", "eve"}
        if speaker not in agents or listener not in agents:
            return JSONResponse(
                status_code=400,
                content={"error": "speaker and listener must be 'adam' or 'eve'"},
            )

        await hub.simulate_agent_dialogue(speaker, listener, message)
        return {
            "status": "sent",
            "speaker": speaker,
            "listener": listener,
            "message": message,
            "timestamp": datetime.utcnow().isoformat(),
        }
    except Exception as e:
        logger.error(f"Dialogue trigger error: {e}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": str(e)},
        )
async def broadcast_to_room(request: Request):
    """Broadcast a message to a specific room or to all rooms.

    Body: {"room": "adam", "data": {...}} or {"broadcast_all": true, "data": {...}}

    Returns:
        A result dict with the recipient count on success; a 400 JSON
        response for a malformed body, an unknown room, or when neither
        ``room`` nor ``broadcast_all`` is given; a 500 JSON response on any
        other failure.
    """
    try:
        try:
            payload = await request.json()
        except ValueError:
            # Malformed JSON is a client error, not a server failure.
            return JSONResponse(
                status_code=400,
                content={"error": "request body must be valid JSON"},
            )
        if not isinstance(payload, dict):
            return JSONResponse(
                status_code=400,
                content={"error": "request body must be a JSON object"},
            )

        data = payload.get("data", {})
        room = payload.get("room")
        broadcast_all = payload.get("broadcast_all", False)

        if broadcast_all:
            recipients = await hub.broadcast(data)
            return {
                "status": "broadcast",
                "recipients": recipients,
                "timestamp": datetime.utcnow().isoformat(),
            }

        if not room:
            return JSONResponse(
                status_code=400,
                content={"error": "Either 'room' or 'broadcast_all' is required"},
            )

        if room not in hub.ROOMS:
            return JSONResponse(
                status_code=400,
                content={"error": f"Invalid room. Available: {list(hub.ROOMS)}"},
            )

        recipients = await hub.send_to_room(room, data)
        return {
            "status": "sent",
            "room": room,
            "recipients": recipients,
            "timestamp": datetime.utcnow().isoformat(),
        }
    except Exception as e:
        logger.error(f"Broadcast error: {e}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": str(e)},
        )
| # Higher limit for internal heartbeat | |
async def receive_heartbeat(request: Request):
    """Receive a heartbeat from the worker process and store its state.

    The JSON object in the body replaces every tracked field of
    ``_worker_state``; fields absent from the payload become ``None``
    (``worker_active`` becomes ``False``).

    Returns:
        ``{"status": "ok"}`` on success, or a 400 JSON response with
        ``{"status": "error", "message": ...}`` when the body cannot be used.
    """
    try:
        payload = await request.json()
        if not isinstance(payload, dict):
            raise ValueError("heartbeat body must be a JSON object")
        # The dict is mutated in place, so no ``global`` statement is needed.
        _worker_state.update({
            "worker_pid": payload.get("worker_pid"),
            "worker_mode": payload.get("worker_mode"),
            "worker_state": payload.get("worker_state"),
            "current_state": payload.get("current_state"),
            "worker_active": payload.get("worker_active", False),
            "stage": payload.get("stage"),
            "last_heartbeat": payload.get("timestamp"),
        })
        # Log a copy so later heartbeats cannot alter what a handler sees.
        logger.info("Heartbeat received", extra={"worker_state": dict(_worker_state)})
        return {"status": "ok"}
    except Exception as e:
        logger.warning(f"Heartbeat parse error: {e}")
        # Report the failure with an error status code, consistent with the
        # other endpoints in this file.
        return JSONResponse(
            status_code=400,
            content={"status": "error", "message": str(e)},
        )
def _read_cain_status() -> Dict[str, Any]:
    """Return the first readable ``cain_status.json`` as a dict.

    Candidate locations are tried in order, starting with the directory named
    by ``OPENCLAW_DATA_DIR`` (default ``/data``). Files that are missing,
    unreadable, not valid JSON, or whose top-level value is not an object are
    skipped.

    Returns:
        The parsed status object, or ``{}`` if no candidate is usable.
    """
    data_dir = Path(os.environ.get("OPENCLAW_DATA_DIR", "/data"))
    candidates = [
        data_dir / "cain_status.json",
        Path("/data/cain_status.json"),
        Path("/app/.openclaw/agents/cain_status.json"),
        Path("/app/openclaw/.openclaw/agents/cain_status.json"),
        Path("/app/cain_status.json"),
    ]
    # dict.fromkeys drops duplicates (the first two entries coincide when the
    # environment variable is unset) while keeping priority order.
    for status_path in dict.fromkeys(candidates):
        try:
            with open(status_path, "r", encoding="utf-8") as f:
                status = json.load(f)
        except (OSError, ValueError):
            # Missing/unreadable file or malformed JSON: try the next one.
            continue
        if isinstance(status, dict):
            return status
    return {}
async def get_state(request: Request):
    """Return a snapshot of server flags, worker state and the Cain status file."""
    now = datetime.utcnow().isoformat()
    server_section = {
        "startup_mode": os.environ.get("STARTUP_MODE", "unknown"),
        "brain_loaded": _brain_loaded,
    }
    cain_status = _read_cain_status()
    return {
        "timestamp": now,
        "server": server_section,
        "worker": _worker_state,
        "cain_status": cain_status,
    }
async def notify_client_disconnect(request: Request):
    """Acknowledge a client disconnect notification sent via sendBeacon.

    The body is logged when it parses as JSON. Any failure is ignored, since
    sendBeacon may deliver malformed data during page unload.
    """
    acknowledgement = {"status": "acknowledged"}
    try:
        body = await request.json()
        logger.info(
            "Client disconnect notification received",
            extra={"action": "client_disconnect", "payload": body},
        )
    except Exception:
        # Deliberate best-effort: always acknowledge.
        pass
    return acknowledgement
async def get_status(request: Request):
    """Liveness probe for frontend polling: always reports ``healthy``."""
    response = {"status": "healthy"}
    response["timestamp"] = datetime.utcnow().isoformat()
    return response
async def event_stream(request: Request):
    """Yield SSE frames: a ``state_update`` every 2 s and periodic keep-alives.

    Each yielded string is a complete, already framed Server-Sent Events
    message. The generator ends when the client disconnects.

    Args:
        request: The incoming request, polled for client disconnection.
    """
    state_interval = 2        # seconds between state_update events
    keepalive_interval = 15   # seconds between keep-alive comments

    client_id = f"sse-{id(request)}"
    logger.info(f"SSE client connecting: {client_id}", extra={"client_id": client_id, "action": "sse_connect"})
    since_keepalive = 0
    try:
        while True:
            # Stop as soon as the client is gone, with explicit logging.
            if await request.is_disconnected():
                logger.info(f"SSE client disconnected: {client_id}", extra={"client_id": client_id, "action": "sse_disconnect"})
                break

            # Send state update event
            state_data = {
                "timestamp": datetime.utcnow().isoformat(),
                "server": {
                    "startup_mode": os.environ.get("STARTUP_MODE", "unknown"),
                    "brain_loaded": _brain_loaded,
                },
                "worker": _worker_state,
                "cain_status": _read_cain_status(),
            }
            yield f"event: state_update\ndata: {json.dumps(state_data)}\n\n"

            # Previously the loop slept 2 s + 13 s per iteration, so state was
            # pushed only every 15 s. Keep the 2 s cadence and emit the
            # keep-alive comment once enough time has accumulated.
            await asyncio.sleep(state_interval)
            since_keepalive += state_interval
            if since_keepalive >= keepalive_interval:
                yield ":keep-alive\n\n"
                since_keepalive = 0
    except asyncio.CancelledError:
        logger.info(f"SSE stream cancelled: {client_id}", extra={"client_id": client_id, "action": "sse_cancelled"})
        raise
    except Exception as e:
        logger.error(f"SSE stream error for {client_id}: {e}", extra={"client_id": client_id, "error": str(e)}, exc_info=True)
    finally:
        logger.info(f"SSE client cleanup: {client_id}", extra={"client_id": client_id, "action": "sse_cleanup"})
async def sse_events(request: Request):
    """SSE endpoint streaming real-time state updates.

    ``event_stream`` yields fully framed SSE text (``event:``/``data:`` lines
    terminated by a blank line), so it is sent as-is through a plain
    streaming response with the ``text/event-stream`` media type.
    """
    # Local import keeps this block self-contained.
    from fastapi.responses import StreamingResponse

    return StreamingResponse(
        event_stream(request),
        media_type="text/event-stream",
        headers={
            # Ask intermediaries not to cache or buffer the stream.
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
async def health_check(request: Request):
    """Return the current time, startup mode and brain-module flag."""
    logger.info("Health check requested", extra={"endpoint": "/health"})
    report = {"timestamp": datetime.utcnow().isoformat()}
    report["startup_mode"] = os.environ.get("STARTUP_MODE", "unknown")
    report["brain_loaded"] = _brain_loaded
    return report
async def status_check(request: Request):
    """Run db, environment and filesystem checks and report the result.

    Returns:
        ``{"status": "healthy", "checks": {...}}`` when every check passes.

    Raises:
        HTTPException: 503 with ``{"status": "degraded", "checks": {...}}``
            as detail when any check fails.
    """
    checks = {"db": "ok", "env": "ok", "fs": "ok"}
    overall_healthy = True

    # 1. Database check - lightweight dataset metadata call
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        repo_id = os.environ.get("OPENCLAW_DATASET_REPO", "")
        if repo_id:
            # repo_info performs a blocking network request; run it in the
            # default executor so the event loop keeps serving other clients.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,
                lambda: api.repo_info(repo_id=repo_id, repo_type="dataset"),
            )
            checks["db"] = "ok"
        else:
            checks["db"] = "error: OPENCLAW_DATASET_REPO not set"
            overall_healthy = False
    except Exception as e:
        checks["db"] = f"error: {str(e)}"
        overall_healthy = False

    # 2. Environment check - verify critical variables
    critical_vars = ["HF_TOKEN", "OPENCLAW_DATASET_REPO"]
    missing_vars = [v for v in critical_vars if not os.environ.get(v)]
    if missing_vars:
        checks["env"] = f"error: missing {', '.join(missing_vars)}"
        overall_healthy = False
    else:
        checks["env"] = "ok"

    # 3. Filesystem check - create and delete a probe file. The name is
    # unique per process and request: with a shared fixed path, concurrent
    # requests could delete each other's file and report a false failure.
    test_file = Path("/tmp") / f"test_write_{os.getpid()}_{id(request)}"
    try:
        test_file.touch()
        test_file.unlink()
        checks["fs"] = "ok"
    except Exception as e:
        checks["fs"] = f"error: {str(e)}"
        overall_healthy = False

    status = "healthy" if overall_healthy else "degraded"
    # Return HTTP 503 if any check fails
    if not overall_healthy:
        raise HTTPException(status_code=503, detail={
            "status": status,
            "checks": checks,
        })
    return {"status": status, "checks": checks}
| # ============================================================================= | |
| # CollaborationHub WebSocket Endpoints | |
| # ============================================================================= | |
async def _dispatch_hub_message(
    websocket: WebSocket, client_id: str, room: str, data: Dict[str, Any]
) -> None:
    """Handle one decoded JSON object received from a hub client."""
    msg_type = data.get("type", "unknown")
    logger.info(
        f"Hub message from {client_id} in room {room}: {msg_type}",
        extra={"client_id": client_id, "room": room, "msg_type": msg_type},
    )

    if msg_type == "broadcast":
        # Broadcast the received message to all rooms, then report back
        recipients = await hub.broadcast(data, exclude_room=None)
        await websocket.send_json({
            "type": "broadcast_result",
            "from_client": client_id,
            "from_room": room,
            "data": data.get("data"),
            "recipients": recipients,
            "timestamp": datetime.utcnow().isoformat(),
        })
    elif msg_type == "room_message":
        # Send to specific room
        target_room = data.get("target_room", "system")
        recipients = await hub.send_to_room(target_room, data.get("data", {}))
        await websocket.send_json({
            "type": "room_message_result",
            "target_room": target_room,
            "recipients": recipients,
            "timestamp": datetime.utcnow().isoformat(),
        })
    elif msg_type == "agent_dialogue":
        # Simulate agent-to-agent dialogue; speaker defaults to the sender's
        # room and listener to the other agent.
        speaker = data.get("speaker", room)
        listener = data.get("listener", "eve" if speaker == "adam" else "adam")
        message = data.get("message", "")
        await hub.simulate_agent_dialogue(speaker, listener, message)
    elif msg_type == "status":
        # Get room status
        await websocket.send_json({
            "type": "room_status",
            "status": hub.get_room_status(),
            "your_room": room,
            "your_info": hub.get_client_info(client_id),
            "timestamp": datetime.utcnow().isoformat(),
        })
    else:
        # Echo back acknowledgment
        await websocket.send_json({
            "type": "ack",
            "client_id": client_id,
            "room": room,
            "received_type": msg_type,
            "timestamp": datetime.utcnow().isoformat(),
        })


async def websocket_hub_endpoint(
    websocket: WebSocket,
    client_id: str = Query(..., description="Client identifier required"),
    room: str = Query("system", description="Room to join: adam, eve, system, dialogue")
) -> None:
    """CollaborationHub WebSocket endpoint with room-based messaging.

    Rooms:
    - adam: Adam agent messages and updates
    - eve: Eve agent messages and updates
    - system: System-wide notifications
    - dialogue: Inter-agent dialogue simulation

    Requires `client_id` and optionally `room` query parameters. Messages
    must be JSON objects; anything else gets an error reply and the
    connection stays open.
    """
    if not client_id:
        await websocket.close(code=4008, reason="client_id query parameter required")
        logger.warning("Hub connection rejected: missing client_id")
        return

    # Connect to the specified room
    connected = await hub.connect(client_id, websocket, room)
    if not connected:
        await websocket.close(code=4009, reason=f"Failed to join room: {room}")
        return

    try:
        while True:
            try:
                data = await websocket.receive_json()
            except json.JSONDecodeError as e:
                logger.warning(
                    f"Invalid JSON from {client_id}: {e}",
                    extra={"client_id": client_id, "error": str(e)},
                )
                await websocket.send_json({"error": "invalid_json", "message": "Failed to parse JSON message"})
                continue

            # Valid JSON that is not an object (list, string, number) has no
            # ``.get``; reply with an error instead of letting AttributeError
            # tear down the connection.
            if not isinstance(data, dict):
                await websocket.send_json({"error": "invalid_message", "message": "Expected a JSON object"})
                continue

            await _dispatch_hub_message(websocket, client_id, room, data)
    except WebSocketDisconnect:
        logger.info(f"Hub WebSocket disconnected: {client_id}", extra={"client_id": client_id, "room": room})
    except Exception as e:
        logger.error(
            f"Hub WebSocket error for {client_id}: {e}",
            extra={"client_id": client_id, "room": room, "error": str(e)},
            exc_info=True,
        )
    finally:
        # Always clean up connection
        hub.disconnect(client_id)
async def websocket_endpoint(
    websocket: WebSocket,
    client_id: str = Query(..., description="Client identifier required")
) -> None:
    """Echo WebSocket endpoint identified by a `client_id` query parameter.

    A connection with an empty `client_id` is closed with code 4008, and one
    whose `client_id` is already registered with code 4009. Each JSON message
    is echoed back with an acknowledgment; unparseable JSON gets an error
    reply and the connection stays open.
    """
    # Guard: empty client_id (a missing one is normally rejected upstream)
    if not client_id:
        await websocket.close(code=4008, reason="client_id query parameter required")
        logger.warning("WebSocket connection rejected: missing client_id")
        return

    # Guard: the manager refuses duplicate client ids
    if not await manager.connect(client_id, websocket):
        await websocket.close(code=4009, reason="client_id already connected")
        return

    try:
        while True:
            try:
                incoming = await websocket.receive_json()
                logger.info(
                    f"Message from {client_id}",
                    extra={"client_id": client_id, "message_type": type(incoming).__name__},
                )
                # Echo back with acknowledgment
                await websocket.send_json(
                    {"status": "received", "client_id": client_id, "data": incoming}
                )
            except json.JSONDecodeError as decode_error:
                logger.warning(
                    f"Invalid JSON from {client_id}: {decode_error}",
                    extra={"client_id": client_id, "error": str(decode_error)},
                )
                await websocket.send_json({"error": "invalid_json", "message": "Failed to parse JSON message"})
    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {client_id}", extra={"client_id": client_id, "event": "disconnect"})
    except Exception as failure:
        logger.error(
            f"WebSocket error for {client_id}: {failure}",
            extra={"client_id": client_id, "error": str(failure)},
            exc_info=True,
        )
    finally:
        # Deregister no matter how the loop ended
        manager.disconnect(client_id)
async def websocket_agents(websocket: WebSocket):
    """WebSocket endpoint pushing a fixed agent status list every 2 seconds.

    The loop ends when the client disconnects or a send fails.
    """
    await websocket.accept()
    logger.info("WebSocket client connected", extra={"endpoint": "/ws/agents"})
    try:
        while True:
            # Send agent status every 2 seconds
            await websocket.send_json({
                "agents": [
                    {"name": "Adam", "status": "Thinking"},
                    {"name": "Eve", "status": "Writing"},
                ]
            })
            await asyncio.sleep(2)
    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected", extra={"endpoint": "/ws/agents"})
    except Exception as e:
        # Consistent with the other WebSocket endpoints in this file: log any
        # other send failure instead of letting it escape the handler.
        logger.error(
            f"Agents WebSocket error: {e}",
            extra={"endpoint": "/ws/agents", "error": str(e)},
            exc_info=True,
        )
async def websocket_state(websocket: WebSocket):
    """WebSocket endpoint pushing server and worker state to the client.

    The current state is sent once immediately after the connection is
    accepted, then again every 2 seconds until the client goes away.
    """

    def build_payload(kind: str) -> Dict[str, Any]:
        # Same envelope for the initial message and the periodic updates.
        return {
            "type": kind,
            "timestamp": datetime.utcnow().isoformat(),
            "data": {
                "server": {
                    "startup_mode": os.environ.get("STARTUP_MODE", "unknown"),
                    "brain_loaded": _brain_loaded,
                },
                "worker": _worker_state,
            },
        }

    await websocket.accept()
    client_id = f"state-{id(websocket)}"
    logger.info("State WebSocket client connected", extra={"endpoint": "/ws/state", "client_id": client_id})

    # Initial state goes out immediately; give up on the socket if that fails
    try:
        await websocket.send_json(build_payload("state"))
        logger.info("Initial state sent", extra={"client_id": client_id})
    except Exception as send_error:
        logger.error(f"Failed to send initial state: {send_error}", extra={"client_id": client_id})
        await websocket.close()
        return

    # Register with the manager so its broadcasts reach this socket too
    manager.active_connections[client_id] = websocket
    try:
        while True:
            # Simple periodic push, sent whether or not the state changed
            await websocket.send_json(build_payload("state_update"))
            await asyncio.sleep(2)
    except WebSocketDisconnect:
        logger.info("State WebSocket client disconnected", extra={"endpoint": "/ws/state", "client_id": client_id})
    except Exception as loop_error:
        logger.error(f"State WebSocket error: {loop_error}", extra={"endpoint": "/ws/state", "client_id": client_id}, exc_info=True)
    finally:
        manager.disconnect(client_id)
# Serve the frontend directory (with index.html handling) at the root path.
app.mount("/", StaticFiles(directory="frontend", html=True), name="frontend")

# =============================================================================
# Main Entry Point
# =============================================================================
if __name__ == "__main__":
    bind_host = "0.0.0.0"
    bind_port = 7860

    # Announce startup with the effective configuration
    startup_details = {
        "event": "startup",
        "port": bind_port,
        "host": bind_host,
        "brain_loaded": _brain_loaded,
        "startup_mode": os.environ.get("STARTUP_MODE", "unknown"),
    }
    logger.info("Cain's Space starting", extra=startup_details)
    logger.info(
        f"Binding to port {bind_port} on {bind_host}",
        extra={"event": "port_binding", "port": bind_port},
    )

    # log_config=None keeps uvicorn from replacing the logging configured
    # earlier in this module
    uvicorn.run(
        app,
        host=bind_host,
        port=bind_port,
        access_log=True,
        log_config=None,
    )