# Claude Code
# Claude Code: Fix WebSocket/SSE disconnect handling and eliminate zombie polling
# 3ef499f
import asyncio
import json
import logging
import os
import sys
import traceback
from datetime import datetime
from pathlib import Path
from typing import Any, Dict
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect, HTTPException, Query
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse, EventSourceResponse
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
import uvicorn
# =============================================================================
# Centralized JSON Logging
# =============================================================================
class JSONFormatter(logging.Formatter):
    """Format log records as single-line JSON objects.

    Every entry has ``timestamp`` (naive UTC, ISO-8601), ``level``,
    ``logger`` and ``message``; an ``exception`` object is added when the
    record carries ``exc_info``.

    Extra fields are merged in from two sources:

    * keys passed as ``logger.info(..., extra={...})``. The stdlib stores
      each such key directly as an attribute of the LogRecord, so they are
      recovered by skipping the attributes every LogRecord has anyway.
      Core fields above are never overwritten by these.
    * a legacy ``extra_fields`` dict attribute, merged last (may override).
    """

    # Attribute names present on every LogRecord, plus ones added while
    # formatting; anything else on a record came from ``extra=``.
    # "color_message" is an ANSI-coloured duplicate uvicorn attaches to some
    # of its own records; "extra_fields" is handled separately below.
    _RESERVED_ATTRS = frozenset(
        logging.LogRecord("", 0, "", 0, "", (), None).__dict__
    ) | {"message", "asctime", "taskName", "extra_fields", "color_message"}

    def format(self, record: logging.LogRecord) -> str:
        log_entry: Dict[str, Any] = {
            "timestamp": datetime.utcnow().isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        # Add exception info if present
        if record.exc_info:
            exc_type, exc_value, _ = record.exc_info
            log_entry["exception"] = {
                "type": exc_type.__name__ if exc_type else None,
                "message": str(exc_value) if exc_value else None,
                "traceback": self.formatException(record.exc_info),
            }
        # Fields passed via ``extra=`` live directly on the record.
        for key, value in record.__dict__.items():
            if key not in self._RESERVED_ATTRS:
                log_entry.setdefault(key, value)
        # Legacy explicit dict of extra fields.
        extra_fields = getattr(record, "extra_fields", None)
        if isinstance(extra_fields, dict):
            log_entry.update(extra_fields)
        # default=str: a log call must never fail because an extra value
        # (e.g. an exception object) is not JSON-serializable.
        return json.dumps(log_entry, default=str)
def setup_logging() -> None:
    """Configure centralized JSON logging for all loggers.

    * root logger: level INFO with one stdout handler using ``JSONFormatter``;
    * ``uvicorn.access``: its own stdout JSON handler and ``propagate=False``,
      so access lines are emitted exactly once;
    * ``uvicorn.error``: own handlers removed, propagates to the root handler.

    Safe to call more than once. The root handler is tagged with a name and
    is only added if no handler with that name is attached yet. Without this
    check, every extra call would stack another root handler and duplicate
    each log line, e.g. when the module is imported both as ``__main__`` and
    by name. A name check is used instead of ``isinstance`` because a second
    import creates a distinct ``JSONFormatter`` class object.
    """
    handler_name = "json-console"

    # Root logger configuration
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    if not any(h.get_name() == handler_name for h in root_logger.handlers):
        # Console handler with JSON formatter
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.set_name(handler_name)
        console_handler.setFormatter(JSONFormatter())
        root_logger.addHandler(console_handler)

    # Uvicorn access logger: dedicated JSON handler. The clear() makes
    # re-running this function idempotent here as well.
    access_logger = logging.getLogger("uvicorn.access")
    access_logger.handlers.clear()
    access_logger.propagate = False
    access_handler = logging.StreamHandler(sys.stdout)
    access_handler.setFormatter(JSONFormatter())
    access_logger.addHandler(access_handler)

    # Uvicorn error logger: rely on the root JSON handler.
    error_logger = logging.getLogger("uvicorn.error")
    error_logger.handlers.clear()
    error_logger.propagate = True
# Setup logging on import: JSON logging is configured as an import-time side
# effect, before the module-level logger below is created.
setup_logging()
logger = logging.getLogger(__name__)
# =============================================================================
# FastAPI Application with Global Exception Handler
# =============================================================================
# Per-client-IP rate limiter (keyed by get_remote_address). No default_limits
# are configured here; each route declares its own limit with
# @limiter.limit(...). slowapi expects the limiter on app.state.limiter, and
# the registered handler turns RateLimitExceeded into an HTTP 429 response.
limiter = Limiter(key_func=get_remote_address)
app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# =============================================================================
# CORS Middleware (HTTP only)
# =============================================================================
# Starlette's CORSMiddleware only acts on HTTP requests; WebSocket handshakes
# bypass it, so the /ws* endpoints get no Origin checking from this middleware.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# maximally permissive. Confirm this is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# =============================================================================
# WebSocket Connection Manager
# =============================================================================
class ConnectionManager:
    """Track WebSocket connections keyed by a client-supplied ``client_id``.

    At most one socket per ``client_id`` is registered; a second ``connect``
    with the same id is rejected without accepting the socket.
    """

    def __init__(self) -> None:
        # client_id -> WebSocket
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, client_id: str, websocket: WebSocket) -> bool:
        """
        Accept ``websocket`` and register it under ``client_id``.

        Returns:
            True if the socket was accepted and registered; False if
            ``client_id`` is already connected. Duplicates are rejected to
            avoid ambiguity, and in that case the socket is NOT accepted, so
            the caller is responsible for closing it.
        """
        if client_id in self.active_connections:
            logger.warning(
                f"Duplicate connection attempt for client_id: {client_id}",
                extra={"client_id": client_id, "action": "reject_duplicate"}
            )
            return False
        await websocket.accept()
        self.active_connections[client_id] = websocket
        logger.info(
            f"Client connected: {client_id}",
            extra={"client_id": client_id, "action": "connect", "active_count": len(self.active_connections)}
        )
        return True

    def disconnect(self, client_id: str) -> None:
        """Unregister ``client_id``; a no-op if it is not registered."""
        if client_id in self.active_connections:
            del self.active_connections[client_id]
            logger.info(
                f"Client disconnected: {client_id}",
                extra={"client_id": client_id, "action": "disconnect", "active_count": len(self.active_connections)}
            )

    async def broadcast(self, message: Any) -> None:
        """
        Send ``message`` as JSON to every registered connection.

        Iterates over a snapshot of the registry: each ``send_json`` awaits,
        and other coroutines may connect or disconnect clients meanwhile.
        Iterating the live dict would then raise ``RuntimeError``.

        Sockets whose send fails are unregistered, but only if their
        ``client_id`` still maps to that same socket, so a newer connection
        under the same id is not evicted.
        """
        failed = []
        for client_id, websocket in list(self.active_connections.items()):
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning(
                    f"Failed to send to client {client_id}: {e}",
                    extra={"client_id": client_id, "error": str(e)}
                )
                failed.append((client_id, websocket))
        # Clean up disconnected clients
        for client_id, websocket in failed:
            if self.active_connections.get(client_id) is websocket:
                self.disconnect(client_id)
# =============================================================================
# CollaborationHub - Room-based WebSocket Message Broker
# =============================================================================
class CollaborationHub:
    """
    Centralized WebSocket message broker supporting distinct rooms/topics.

    Features:
    - Rooms: adam, eve, system, dialogue (for Agent Dialogue simulation)
    - Each client_id can be in at most one room at a time
    - Broadcast to specific rooms or all rooms
    - Targeted messaging to individual clients

    Sending iterates over snapshots of room membership, because every
    ``send_json`` awaits and other coroutines may join or leave rooms
    meanwhile. A socket whose send fails is dropped only if it is still the
    registered socket for that client_id.
    """

    # Available room types
    ROOMS = {"adam", "eve", "system", "dialogue"}

    def __init__(self) -> None:
        # Room -> {client_id -> WebSocket}
        self.rooms: Dict[str, Dict[str, WebSocket]] = {room: {} for room in self.ROOMS}
        # Client tracking: client_id -> room
        self.client_room: Dict[str, str] = {}
        # Client info: client_id -> metadata ("room", "connected_at")
        self.client_info: Dict[str, Dict[str, Any]] = {}

    async def connect(self, client_id: str, websocket: WebSocket, room: str = "system") -> bool:
        """
        Accept ``websocket`` and add it to ``room`` under ``client_id``.

        On success a ``room_joined`` message is sent to the client.

        Args:
            client_id: Unique identifier for the client
            websocket: WebSocket connection object
            room: Room to join (adam, eve, system, dialogue)

        Returns:
            True if connected; False if ``room`` is unknown or ``client_id``
            is already in a room. In the False case the socket is NOT
            accepted, and the caller must close it.

        Raises:
            Whatever ``accept()`` / ``send_json()`` raise. If the
            ``room_joined`` send fails, the client is unregistered before
            re-raising. Otherwise the stale entry would make every later
            connect with this client_id be rejected as a duplicate.
        """
        if room not in self.ROOMS:
            logger.warning(
                f"Invalid room: {room}",
                extra={"client_id": client_id, "room": room, "action": "reject_invalid_room"}
            )
            return False
        # Reject if client already connected to any room
        if client_id in self.client_room:
            existing_room = self.client_room[client_id]
            logger.warning(
                f"Client {client_id} already in room {existing_room}",
                extra={"client_id": client_id, "existing_room": existing_room, "action": "reject_duplicate"}
            )
            return False
        await websocket.accept()
        self.rooms[room][client_id] = websocket
        self.client_room[client_id] = room
        self.client_info[client_id] = {
            "room": room,
            "connected_at": datetime.utcnow().isoformat()
        }
        logger.info(
            f"Client {client_id} joined room: {room}",
            extra={
                "client_id": client_id,
                "room": room,
                "action": "join_room",
                "room_members": len(self.rooms[room])
            }
        )
        # Notify client of successful join
        try:
            await websocket.send_json({
                "type": "room_joined",
                "room": room,
                "client_id": client_id,
                "timestamp": datetime.utcnow().isoformat()
            })
        except Exception:
            self.disconnect(client_id)
            raise
        return True

    def disconnect(self, client_id: str) -> None:
        """Remove ``client_id`` from its room and metadata; a no-op if unknown."""
        if client_id not in self.client_room:
            return
        room = self.client_room[client_id]
        if client_id in self.rooms[room]:
            del self.rooms[room][client_id]
        del self.client_room[client_id]
        if client_id in self.client_info:
            del self.client_info[client_id]
        logger.info(
            f"Client {client_id} left room: {room}",
            extra={
                "client_id": client_id,
                "room": room,
                "action": "leave_room",
                "room_members": len(self.rooms[room])
            }
        )

    def _disconnect_if_current(self, client_id: str, room: str, websocket: WebSocket) -> None:
        """Disconnect ``client_id`` only if ``websocket`` is still its socket in ``room``."""
        if self.rooms.get(room, {}).get(client_id) is websocket:
            self.disconnect(client_id)

    async def send_to_room(self, room: str, message: Any) -> int:
        """
        Send a message to all clients in a specific room.

        Args:
            room: Target room name
            message: Message payload (will be JSON serialized)

        Returns:
            Number of clients the message was successfully sent to
            (0 for an unknown or empty room).
        """
        if room not in self.ROOMS:
            logger.warning(f"Cannot send to invalid room: {room}")
            return 0
        members = self.rooms[room]
        if not members:
            return 0
        failed = []
        sent_count = 0
        # Snapshot: membership may change while send_json() awaits.
        for client_id, websocket in list(members.items()):
            try:
                await websocket.send_json(message)
                sent_count += 1
            except Exception as e:
                logger.warning(
                    f"Failed to send to {client_id} in room {room}: {e}",
                    extra={"client_id": client_id, "room": room, "error": str(e)}
                )
                failed.append((client_id, websocket))
        # Clean up disconnected clients
        for client_id, websocket in failed:
            self._disconnect_if_current(client_id, room, websocket)
        return sent_count

    async def send_to_client(self, client_id: str, message: Any) -> bool:
        """
        Send a message to a specific client.

        Args:
            client_id: Target client ID
            message: Message payload (will be JSON serialized)

        Returns:
            True if sent, False if client not found or send failed
            (a failed client is disconnected).
        """
        room = self.client_room.get(client_id)
        if room is None:
            return False
        websocket = self.rooms[room].get(client_id)
        if websocket is None:
            return False
        try:
            await websocket.send_json(message)
            return True
        except Exception as e:
            logger.warning(
                f"Failed to send to {client_id}: {e}",
                extra={"client_id": client_id, "error": str(e)}
            )
            self._disconnect_if_current(client_id, room, websocket)
            return False

    async def broadcast(self, message: Any, exclude_room: str = None) -> int:
        """
        Broadcast a message to all rooms (or all except one).

        Args:
            message: Message payload (will be JSON serialized)
            exclude_room: Optional room to exclude from broadcast

        Returns:
            Total number of clients the message was sent to
        """
        total_sent = 0
        for room in self.ROOMS:
            if exclude_room and room == exclude_room:
                continue
            total_sent += await self.send_to_room(room, message)
        return total_sent

    async def simulate_agent_dialogue(self, speaker: str, listener: str, message: str) -> None:
        """
        Simulate inter-agent dialogue by sending a message from one agent to another.

        The message goes to the listener's room, and an ``agent_dialogue_log``
        copy goes to the "system" room. Calls naming an agent other than
        adam/eve are logged and ignored.

        Args:
            speaker: Room name of the speaking agent (adam or eve)
            listener: Room name of the listening agent (adam or eve)
            message: The message content
        """
        if speaker not in {"adam", "eve"} or listener not in {"adam", "eve"}:
            logger.warning(f"Invalid agent dialogue: {speaker} -> {listener}")
            return
        dialogue_message = {
            "type": "agent_dialogue",
            "speaker": speaker,
            "listener": listener,
            "message": message,
            "timestamp": datetime.utcnow().isoformat()
        }
        # Send to listener's room
        sent_count = await self.send_to_room(listener, dialogue_message)
        # Also broadcast to system room for logging
        await self.send_to_room("system", {
            **dialogue_message,
            "type": "agent_dialogue_log"
        })
        # str(): message may arrive as a non-string JSON value; slicing/len
        # on it must not crash the caller after delivery.
        text = str(message)
        logger.info(
            f"Agent dialogue: {speaker} -> {listener}: {text[:50]}...",
            extra={
                "speaker": speaker,
                "listener": listener,
                "recipients": sent_count,
                "message_length": len(text)
            }
        )

    def get_room_status(self) -> Dict[str, Any]:
        """Return member ids and member count for every room."""
        return {
            room: {
                "members": list(clients.keys()),
                "count": len(clients)
            }
            for room, clients in self.rooms.items()
        }

    def get_client_info(self, client_id: str) -> Dict[str, Any]:
        """Return the metadata stored for ``client_id``, or ``{}`` if unknown."""
        return self.client_info.get(client_id, {})
# Global manager instances, shared by the endpoints in this module.
# `manager`: flat client_id -> WebSocket registry (used by /ws; /ws/state
#            also registers its sockets here).
# `hub`: room-based broker used by /ws/hub and the /api/hub/* endpoints.
manager = ConnectionManager()
hub = CollaborationHub()
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Last-resort handler: log any unhandled exception and return a 500 JSON body.

    The traceback is rendered from ``exc`` itself, not from the ambient
    ``sys.exc_info()``, so the log stays correct even when this coroutine
    runs outside an active ``except`` block.

    NOTE(review): the response echoes the exception type and message to the
    client, which may leak internal details. Consider a generic message.
    """
    exc_type = type(exc).__name__
    exc_message = str(exc)
    exc_traceback = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__)
    )
    # Log the full exception with stacktrace
    logger.error(
        f"Unhandled exception: {exc_type}: {exc_message}",
        extra={
            "exception_type": exc_type,
            "exception_message": exc_message,
            "traceback": exc_traceback,
            "path": request.url.path,
            "method": request.method,
        },
        exc_info=exc,
    )
    # Return 500 JSON response with error details
    return JSONResponse(
        status_code=500,
        content={
            "error": {
                "type": exc_type,
                "message": exc_message,
            }
        },
    )
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Log a request validation failure and return HTTP 422.

    Body: ``{"error": {"type": "ValidationError", "details": [...]}}``.
    """
    # Pydantic error dicts can carry non-JSON-serializable values (e.g. the
    # original exception object under "ctx"). jsonable_encoder converts them,
    # as FastAPI's default handler does, so building the 422 cannot itself
    # crash with a TypeError.
    from fastapi.encoders import jsonable_encoder

    errors = jsonable_encoder(exc.errors())
    logger.warning(
        f"Validation error: {errors}",
        extra={"validation_errors": errors, "path": request.url.path},
    )
    return JSONResponse(
        status_code=422,
        content={"error": {"type": "ValidationError", "details": errors}},
    )
# Whether the optional `brain_minimal` module could be imported at startup.
try:
    import brain_minimal
except ImportError:
    _brain_loaded = False
else:
    _brain_loaded = True

# Latest worker status, refreshed in place by /internal/heartbeat. Key order
# is the order in which it is serialized to clients.
_worker_state = dict(
    worker_pid=None,
    worker_mode=None,
    worker_state="inactive",
    current_state="idle",
    worker_active=False,
    stage="STARTUP",
    last_heartbeat=None,
)
# =============================================================================
# CollaborationHub HTTP API Endpoints
# =============================================================================
@app.get("/api/hub/status")
@limiter.limit("60/minute")
async def hub_status(request: Request):
    """Report per-room membership of the CollaborationHub plus a total client count.

    ``request`` is not read here; slowapi's limiter requires it in the signature.
    """
    snapshot = {"timestamp": datetime.utcnow().isoformat()}
    snapshot["rooms"] = hub.get_room_status()
    snapshot["total_clients"] = len(hub.client_room)
    return snapshot
@app.post("/api/hub/dialogue")
@limiter.limit("30/minute")
async def trigger_agent_dialogue(request: Request):
    """
    Trigger an agent dialogue simulation.

    Body: {"speaker": "adam", "listener": "eve", "message": "Hello!"}
    ``speaker`` defaults to "adam" and ``listener`` to "eve".

    Responses:
        200 with the echoed dialogue on success.
        400 if the body is not a JSON object, ``message`` is missing, empty or
            not a string, or ``speaker``/``listener`` is not "adam"/"eve".
            The hub silently drops dialogues between unknown agents, so they
            are rejected here rather than reported as "sent".
        500 if dispatching the dialogue fails unexpectedly.
    """
    agents = ("adam", "eve")
    try:
        payload = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        return JSONResponse(
            status_code=400,
            content={"error": "request body must be valid JSON"}
        )
    if not isinstance(payload, dict):
        return JSONResponse(
            status_code=400,
            content={"error": "request body must be a JSON object"}
        )

    speaker = payload.get("speaker", "adam")
    listener = payload.get("listener", "eve")
    message = payload.get("message", "")
    if not message:
        return JSONResponse(
            status_code=400,
            content={"error": "message is required"}
        )
    if not isinstance(message, str):
        return JSONResponse(
            status_code=400,
            content={"error": "message must be a string"}
        )
    # Tuple membership uses ==, so unhashable JSON values can't raise here.
    if speaker not in agents or listener not in agents:
        return JSONResponse(
            status_code=400,
            content={"error": "speaker and listener must be 'adam' or 'eve'"}
        )

    try:
        await hub.simulate_agent_dialogue(speaker, listener, message)
    except Exception as e:
        logger.error(f"Dialogue trigger error: {e}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )
    return {
        "status": "sent",
        "speaker": speaker,
        "listener": listener,
        "message": message,
        "timestamp": datetime.utcnow().isoformat()
    }
@app.post("/api/hub/broadcast")
@limiter.limit("30/minute")
async def broadcast_to_room(request: Request):
    """
    Broadcast a message to a specific room or all rooms.

    Body: {"room": "adam", "data": {...}} or {"broadcast_all": true, "data": {...}}
    ``broadcast_all`` takes precedence over ``room``.

    Responses:
        200 with the recipient count.
        400 for a non-JSON / non-object body, an unknown room, or when neither
            ``room`` nor ``broadcast_all`` is given.
        500 if sending fails unexpectedly.
    """
    try:
        payload = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        return JSONResponse(
            status_code=400,
            content={"error": "request body must be valid JSON"}
        )
    if not isinstance(payload, dict):
        return JSONResponse(
            status_code=400,
            content={"error": "request body must be a JSON object"}
        )

    data = payload.get("data", {})
    room = payload.get("room")
    broadcast_all = payload.get("broadcast_all", False)

    # Validate before touching the hub; the isinstance check also keeps an
    # unhashable JSON value (list/object) from raising on set membership.
    if not broadcast_all:
        if not room:
            return JSONResponse(
                status_code=400,
                content={"error": "Either 'room' or 'broadcast_all' is required"}
            )
        if not isinstance(room, str) or room not in hub.ROOMS:
            return JSONResponse(
                status_code=400,
                content={"error": f"Invalid room. Available: {list(hub.ROOMS)}"}
            )

    try:
        if broadcast_all:
            recipients = await hub.broadcast(data)
            return {
                "status": "broadcast",
                "recipients": recipients,
                "timestamp": datetime.utcnow().isoformat()
            }
        recipients = await hub.send_to_room(room, data)
        return {
            "status": "sent",
            "room": room,
            "recipients": recipients,
            "timestamp": datetime.utcnow().isoformat()
        }
    except Exception as e:
        logger.error(f"Broadcast error: {e}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )
@app.post("/internal/heartbeat")
@limiter.limit("100/minute")  # Higher limit for internal heartbeat
async def receive_heartbeat(request: Request):
    """Store the worker's latest heartbeat in the shared ``_worker_state`` dict.

    The dict is updated in place, never rebound. Every tracked key is
    overwritten on each call: keys missing from the payload become ``None``,
    except ``worker_active``, which becomes ``False``. ``last_heartbeat``
    comes from the payload's ``timestamp``. Any failure, such as an
    unparseable or non-object body, is reported in the response body with
    HTTP 200.

    NOTE(review): this endpoint has no authentication. Confirm it cannot be
    reached from outside the deployment.
    """
    try:
        beat = await request.json()
        passthrough = ("worker_pid", "worker_mode", "worker_state", "current_state")
        fresh = {name: beat.get(name) for name in passthrough}
        fresh["worker_active"] = beat.get("worker_active", False)
        fresh["stage"] = beat.get("stage")
        fresh["last_heartbeat"] = beat.get("timestamp")
        _worker_state.update(fresh)
        logger.info("Heartbeat received", extra={"worker_state": _worker_state})
        return {"status": "ok"}
    except Exception as err:
        logger.warning(f"Heartbeat parse error: {err}")
        return {"status": "error", "message": str(err)}
def _read_cain_status() -> Dict[str, Any]:
    """Return the parsed contents of the first readable ``cain_status.json``.

    Candidate paths are tried in order, starting with
    ``$OPENCLAW_DATA_DIR/cain_status.json`` (default directory ``/data``);
    duplicate candidates are tried only once. Missing or unreadable files,
    and files that are not valid UTF-8 JSON, are skipped (best-effort).
    ``{}`` is returned when no candidate yields JSON.
    """
    data_dir = Path(os.environ.get("OPENCLAW_DATA_DIR", "/data"))
    candidates = [
        data_dir / "cain_status.json",
        Path("/data/cain_status.json"),
        Path("/app/.openclaw/agents/cain_status.json"),
        Path("/app/openclaw/.openclaw/agents/cain_status.json"),
        Path("/app/cain_status.json"),
    ]
    # dict.fromkeys de-duplicates while keeping priority order (with the
    # default OPENCLAW_DATA_DIR the first two entries are the same path).
    for status_path in dict.fromkeys(candidates):
        try:
            # JSON text is UTF-8 (RFC 8259); don't depend on the locale.
            with status_path.open("r", encoding="utf-8") as f:
                return json.load(f)
        except (OSError, ValueError):
            # OSError: missing/unreadable/directory. ValueError covers both
            # json.JSONDecodeError and UnicodeDecodeError. Try the next path.
            continue
    return {}
@app.get("/api/state")
@limiter.limit("60/minute")
async def get_state(request: Request):
    """Snapshot of server config, the last worker heartbeat, and cain_status.json."""
    now = datetime.utcnow().isoformat()
    server_info = {
        "startup_mode": os.environ.get("STARTUP_MODE", "unknown"),
        "brain_loaded": _brain_loaded,
    }
    snapshot = {"timestamp": now, "server": server_info, "worker": _worker_state}
    snapshot["cain_status"] = _read_cain_status()
    return snapshot
@app.post("/api/state")
@limiter.limit("30/minute")
async def notify_client_disconnect(request: Request):
    """Log a client's disconnect beacon (sent with sendBeacon on page unload).

    The reply is always ``{"status": "acknowledged"}``. Beacons fired during
    unload may carry truncated or non-JSON bodies; those are acknowledged
    without being logged.
    """
    try:
        beacon = await request.json()
        logger.info(
            "Client disconnect notification received",
            extra={"action": "client_disconnect", "payload": beacon},
        )
    except Exception:
        # sendBeacon may send malformed data during unload, just acknowledge
        pass
    return {"status": "acknowledged"}
@app.get("/api/status")
@limiter.limit("60/minute")
async def get_status(request: Request):
    """Liveness probe polled by the frontend: always "healthy" plus a UTC timestamp."""
    return dict(status="healthy", timestamp=datetime.utcnow().isoformat())
async def event_stream(
    request: Request,
    interval: float = 2.0,
    keepalive_interval: float = 15.0,
):
    """Yield Server-Sent Events frames until the client disconnects.

    Emits a ``state_update`` event every ``interval`` seconds, starting
    immediately. Its payload has the same shape as GET /api/state. A
    ``:keep-alive`` comment frame is added whenever ``keepalive_interval``
    seconds have passed since the previous one.

    Frames are yielded as already-encoded SSE text (``event:``/``data:``
    lines ended by a blank line, or ``:`` comments). The response wrapping
    this generator must therefore pass chunks through verbatim.

    The disconnect check runs before every update, so an abandoned stream
    stops within about ``interval`` seconds. The previous loop only polled
    once per ~15 s and, despite its comments, sent state that rarely too.

    Args:
        request: Incoming request, used for ``is_disconnected()`` and an id.
        interval: Seconds between state updates.
        keepalive_interval: Minimum seconds between keep-alive comments.
    """
    client_id = f"sse-{id(request)}"
    logger.info(f"SSE client connecting: {client_id}", extra={"client_id": client_id, "action": "sse_connect"})
    loop = asyncio.get_running_loop()
    last_keepalive = loop.time()
    try:
        while True:
            if await request.is_disconnected():
                logger.info(f"SSE client disconnected: {client_id}", extra={"client_id": client_id, "action": "sse_disconnect"})
                break
            # Send state update event
            state_data = {
                "timestamp": datetime.utcnow().isoformat(),
                "server": {
                    "startup_mode": os.environ.get("STARTUP_MODE", "unknown"),
                    "brain_loaded": _brain_loaded,
                },
                "worker": _worker_state,
                "cain_status": _read_cain_status(),
            }
            yield f"event: state_update\ndata: {json.dumps(state_data)}\n\n"
            # Keep-alive comment on its own cadence, independent of updates.
            now = loop.time()
            if now - last_keepalive >= keepalive_interval:
                yield ":keep-alive\n\n"
                last_keepalive = now
            await asyncio.sleep(interval)
    except asyncio.CancelledError:
        logger.info(f"SSE stream cancelled: {client_id}", extra={"client_id": client_id, "action": "sse_cancelled"})
        raise
    except Exception as e:
        logger.error(f"SSE stream error for {client_id}: {e}", extra={"client_id": client_id, "error": str(e)}, exc_info=True)
    finally:
        logger.info(f"SSE client cleanup: {client_id}", extra={"client_id": client_id, "action": "sse_cleanup"})
@app.get("/api/events")
@limiter.limit("60/minute")
async def sse_events(request: Request):
    """SSE endpoint for real-time state updates.

    ``event_stream`` yields fully-formed SSE frames (``event:``/``data:``
    lines and ``:`` comments), so they are streamed verbatim with the
    ``text/event-stream`` media type. An SSE-encoding response, such as
    sse-starlette's ``EventSourceResponse``, would re-wrap each frame as
    ``data:`` lines. Browsers would then never see the ``state_update``
    event type.
    """
    from fastapi.responses import StreamingResponse

    return StreamingResponse(
        event_stream(request),
        media_type="text/event-stream",
        headers={
            # Don't let browsers or (nginx-style) reverse proxies cache or
            # buffer the stream.
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
@app.get("/health")
@limiter.limit("60/minute")
async def health_check(request: Request):
    """Lightweight health probe: startup mode and whether brain_minimal imported."""
    logger.info("Health check requested", extra={"endpoint": "/health"})
    report = {}
    report["timestamp"] = datetime.utcnow().isoformat()
    report["startup_mode"] = os.environ.get("STARTUP_MODE", "unknown")
    report["brain_loaded"] = _brain_loaded
    return report
@app.get("/status")
@limiter.limit("60/minute")
async def status_check(request: Request):
    """Comprehensive health check with db, env, and filesystem validation.

    Checks:
      * db:  ``HfApi().repo_info`` on the ``$OPENCLAW_DATASET_REPO`` dataset.
             The call runs in a worker thread under a timeout, so a slow or
             hung Hub request cannot block the event loop (which would stall
             every WebSocket/SSE client).
      * env: ``HF_TOKEN`` and ``OPENCLAW_DATASET_REPO`` are set.
      * fs:  a uniquely named temp file can be created and removed in /tmp.
             A unique name avoids the race where concurrent checks share one
             fixed path and unlink each other's file.

    Returns ``{"status": "healthy", "checks": {...}}``. If any check fails,
    raises ``HTTPException(503)`` whose ``detail`` has the same shape with
    status ``"degraded"``.
    """
    import tempfile

    hub_timeout_s = 10.0
    checks = {"db": "ok", "env": "ok", "fs": "ok"}
    overall_healthy = True

    # 1. Database check - lightweight dataset metadata call
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        repo_id = os.environ.get("OPENCLAW_DATASET_REPO", "")
        if repo_id:
            loop = asyncio.get_running_loop()
            await asyncio.wait_for(
                loop.run_in_executor(
                    None,
                    lambda: api.repo_info(repo_id=repo_id, repo_type="dataset"),
                ),
                timeout=hub_timeout_s,
            )
            checks["db"] = "ok"
        else:
            checks["db"] = "error: OPENCLAW_DATASET_REPO not set"
            overall_healthy = False
    except asyncio.TimeoutError:
        # str(TimeoutError()) is empty; give the check a readable reason.
        checks["db"] = f"error: timed out after {hub_timeout_s}s"
        overall_healthy = False
    except Exception as e:
        checks["db"] = f"error: {str(e)}"
        overall_healthy = False

    # 2. Environment check - verify critical variables
    critical_vars = ["HF_TOKEN", "OPENCLAW_DATASET_REPO"]
    missing_vars = [v for v in critical_vars if not os.environ.get(v)]
    if missing_vars:
        checks["env"] = f"error: missing {', '.join(missing_vars)}"
        overall_healthy = False
    else:
        checks["env"] = "ok"

    # 3. Filesystem check - create and delete a uniquely named test file
    try:
        with tempfile.NamedTemporaryFile(dir="/tmp", prefix="test_write_"):
            pass  # closing the file deletes it
        checks["fs"] = "ok"
    except Exception as e:
        checks["fs"] = f"error: {str(e)}"
        overall_healthy = False

    status = "healthy" if overall_healthy else "degraded"
    # Return HTTP 503 if critical checks fail
    if not overall_healthy:
        raise HTTPException(status_code=503, detail={
            "status": status,
            "checks": checks
        })
    return {"status": status, "checks": checks}
# =============================================================================
# CollaborationHub WebSocket Endpoints
# =============================================================================
async def _dispatch_hub_message(
    websocket: WebSocket, client_id: str, room: str, data: Dict[str, Any]
) -> None:
    """Handle one JSON-object message from a /ws/hub client, replying on ``websocket``."""
    msg_type = data.get("type", "unknown")
    logger.info(
        f"Hub message from {client_id} in room {room}: {msg_type}",
        extra={"client_id": client_id, "room": room, "msg_type": msg_type}
    )
    if msg_type == "broadcast":
        # Broadcast the whole message to all rooms (sender's room included)
        response = {
            "type": "broadcast_result",
            "from_client": client_id,
            "from_room": room,
            "data": data.get("data"),
            "recipients": await hub.broadcast(data, exclude_room=None),
            "timestamp": datetime.utcnow().isoformat()
        }
        await websocket.send_json(response)
    elif msg_type == "room_message":
        # Send to specific room
        target_room = data.get("target_room", "system")
        recipients = await hub.send_to_room(target_room, data.get("data", {}))
        await websocket.send_json({
            "type": "room_message_result",
            "target_room": target_room,
            "recipients": recipients,
            "timestamp": datetime.utcnow().isoformat()
        })
    elif msg_type == "agent_dialogue":
        # Simulate agent-to-agent dialogue
        speaker = data.get("speaker", room)
        listener = data.get("listener", "eve" if speaker == "adam" else "adam")
        message = data.get("message", "")
        # Non-string values (e.g. lists) would raise inside the hub and
        # tear down this connection; reject them with a reply instead.
        if not all(isinstance(v, str) for v in (speaker, listener, message)):
            await websocket.send_json({
                "error": "invalid_message",
                "message": "speaker, listener and message must be strings"
            })
            return
        await hub.simulate_agent_dialogue(speaker, listener, message)
    elif msg_type == "status":
        # Get room status
        await websocket.send_json({
            "type": "room_status",
            "status": hub.get_room_status(),
            "your_room": room,
            "your_info": hub.get_client_info(client_id),
            "timestamp": datetime.utcnow().isoformat()
        })
    else:
        # Echo back acknowledgment
        await websocket.send_json({
            "type": "ack",
            "client_id": client_id,
            "room": room,
            "received_type": msg_type,
            "timestamp": datetime.utcnow().isoformat()
        })


@app.websocket("/ws/hub")
async def websocket_hub_endpoint(
    websocket: WebSocket,
    client_id: str = Query(..., description="Client identifier required"),
    room: str = Query("system", description="Room to join: adam, eve, system, dialogue")
) -> None:
    """
    CollaborationHub WebSocket endpoint with room-based messaging.

    Rooms:
    - adam: Adam agent messages and updates
    - eve: Eve agent messages and updates
    - system: System-wide notifications
    - dialogue: Inter-agent dialogue simulation

    Requires `client_id` and optionally `room` query parameters. An empty
    client_id is closed with 4008. An unknown room or an already-connected
    client_id is closed with 4009.

    Each incoming message must be a JSON object with a ``type``. Invalid
    JSON and non-object JSON get an error reply and the connection stays
    open. The client is always removed from the hub when the loop ends.
    """
    if not client_id:
        await websocket.close(code=4008, reason="client_id query parameter required")
        logger.warning("Hub connection rejected: missing client_id")
        return
    # Connect to the specified room
    connected = await hub.connect(client_id, websocket, room)
    if not connected:
        await websocket.close(code=4009, reason=f"Failed to join room: {room}")
        return
    try:
        while True:
            try:
                data = await websocket.receive_json()
            except json.JSONDecodeError as e:
                logger.warning(
                    f"Invalid JSON from {client_id}: {e}",
                    extra={"client_id": client_id, "error": str(e)}
                )
                await websocket.send_json({"error": "invalid_json", "message": "Failed to parse JSON message"})
                continue
            if not isinstance(data, dict):
                # Valid JSON but not an object (list, string, number): reply
                # instead of letting data.get() raise and drop the connection.
                logger.warning(
                    f"Non-object message from {client_id}",
                    extra={"client_id": client_id, "message_type": type(data).__name__}
                )
                await websocket.send_json({"error": "invalid_message", "message": "Expected a JSON object"})
                continue
            await _dispatch_hub_message(websocket, client_id, room, data)
    except WebSocketDisconnect:
        logger.info(f"Hub WebSocket disconnected: {client_id}", extra={"client_id": client_id, "room": room})
    except Exception as e:
        logger.error(
            f"Hub WebSocket error for {client_id}: {e}",
            extra={"client_id": client_id, "room": room, "error": str(e)},
            exc_info=True
        )
    finally:
        # Always clean up connection
        hub.disconnect(client_id)
@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    client_id: str = Query(..., description="Client identifier required")
) -> None:
    """
    Echo endpoint for identified clients.

    The ``client_id`` query parameter must be non-empty and not already
    connected. Otherwise the handshake is closed before accept, with code
    4008 for an empty id and 4009 for a duplicate. Every JSON message is
    echoed back as ``{"status": "received", "client_id": ..., "data": ...}``.
    Undecodable JSON gets an ``invalid_json`` error reply. The client is
    unregistered from ``manager`` however the loop ends.
    """
    # FastAPI enforces presence of the parameter; this catches an empty value.
    if not client_id:
        await websocket.close(code=4008, reason="client_id query parameter required")
        logger.warning("WebSocket connection rejected: missing client_id")
        return
    # The manager accepts the socket unless the id is already taken.
    if not await manager.connect(client_id, websocket):
        await websocket.close(code=4009, reason="client_id already connected")
        return
    try:
        while True:
            try:
                incoming = await websocket.receive_json()
            except json.JSONDecodeError as decode_err:
                logger.warning(
                    f"Invalid JSON from {client_id}: {decode_err}",
                    extra={"client_id": client_id, "error": str(decode_err)}
                )
                await websocket.send_json({"error": "invalid_json", "message": "Failed to parse JSON message"})
                continue
            logger.info(
                f"Message from {client_id}",
                extra={"client_id": client_id, "message_type": type(incoming).__name__}
            )
            echo = {"status": "received", "client_id": client_id, "data": incoming}
            await websocket.send_json(echo)
    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {client_id}", extra={"client_id": client_id, "event": "disconnect"})
    except Exception as failure:
        logger.error(
            f"WebSocket error for {client_id}: {failure}",
            extra={"client_id": client_id, "error": str(failure)},
            exc_info=True
        )
    finally:
        manager.disconnect(client_id)
@app.websocket("/ws/agents")
async def websocket_agents(websocket: WebSocket):
    """WebSocket endpoint pushing an agent status list every 2 seconds.

    This loop never reads from the socket, so a client disconnect is only
    noticed when the next send fails. Depending on the ASGI server and
    Starlette version, that failure surfaces as ``WebSocketDisconnect`` or
    as another exception. Both end the loop with a log line instead of an
    unhandled traceback.

    NOTE(review): the agent names and statuses are static placeholders.
    """
    await websocket.accept()
    logger.info("WebSocket client connected", extra={"endpoint": "/ws/agents"})
    # Constant payload, built once; send_json serializes it on every send.
    agents_payload = {
        "agents": [
            {"name": "Adam", "status": "Thinking"},
            {"name": "Eve", "status": "Writing"}
        ]
    }
    try:
        while True:
            await websocket.send_json(agents_payload)
            await asyncio.sleep(2)
    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected", extra={"endpoint": "/ws/agents"})
    except Exception as e:
        # Typically a send on a socket the client has already closed.
        logger.warning(
            f"Agents WebSocket closed after send failure: {e}",
            extra={"endpoint": "/ws/agents", "error": str(e)}
        )
@app.websocket("/ws/state")
async def websocket_state(websocket: WebSocket):
    """
    WebSocket endpoint for real-time state updates (push replacement for
    polling /api/state).

    Sends a ``{"type": "state", ...}`` snapshot right after accept, then a
    ``{"type": "state_update", ...}`` snapshot every 2 seconds. Updates are
    sent unconditionally, not only when the worker state changes. Each
    snapshot's ``data`` contains ``server`` (startup mode, brain_loaded) and
    ``worker`` (last heartbeat state). Unlike GET /api/state it does not
    include ``cain_status``.

    While connected, the socket is registered in ``manager.active_connections``
    under ``state-<id>``, so ``manager.broadcast`` reaches it as well.
    """
    await websocket.accept()
    client_id = f"state-{id(websocket)}"
    logger.info("State WebSocket client connected", extra={"endpoint": "/ws/state", "client_id": client_id})

    def snapshot(kind: str) -> Dict[str, Any]:
        """Build one state message with ``type`` set to ``kind``."""
        return {
            "type": kind,
            "timestamp": datetime.utcnow().isoformat(),
            "data": {
                "server": {
                    "startup_mode": os.environ.get("STARTUP_MODE", "unknown"),
                    "brain_loaded": _brain_loaded,
                },
                "worker": _worker_state,
            }
        }

    # Send initial state immediately
    try:
        await websocket.send_json(snapshot("state"))
        logger.info("Initial state sent", extra={"client_id": client_id})
    except Exception as e:
        logger.error(f"Failed to send initial state: {e}", extra={"client_id": client_id})
        # The socket is most likely gone already, so close() may raise too;
        # don't let that escape as an unhandled error.
        try:
            await websocket.close()
        except Exception:
            pass
        return

    # Track this connection for broadcasts
    manager.active_connections[client_id] = websocket
    try:
        while True:
            await websocket.send_json(snapshot("state_update"))
            await asyncio.sleep(2)
    except WebSocketDisconnect:
        logger.info("State WebSocket client disconnected", extra={"endpoint": "/ws/state", "client_id": client_id})
    except Exception as e:
        logger.error(f"State WebSocket error: {e}", extra={"endpoint": "/ws/state", "client_id": client_id}, exc_info=True)
    finally:
        manager.disconnect(client_id)
# Mount static files to serve frontend at root path.
# This must stay after every route/WebSocket registration: Starlette matches
# routes in registration order, and a mount at "/" matches every path, so
# anything registered after it would be unreachable.
# NOTE(review): StaticFiles checks that the "frontend" directory exists
# (relative to the working directory) when constructed. Confirm the
# deployment's working directory.
app.mount("/", StaticFiles(directory="frontend", html=True), name="frontend")
# =============================================================================
# Main Entry Point
# =============================================================================
if __name__ == "__main__":
    # Clear startup logging
    logger.info(
        "Cain's Space starting",
        extra={
            "event": "startup",
            "port": 7860,
            "host": "0.0.0.0",
            "brain_loaded": _brain_loaded,
            "startup_mode": os.environ.get("STARTUP_MODE", "unknown"),
        },
    )
    logger.info("Binding to port 7860 on 0.0.0.0", extra={"event": "port_binding", "port": 7860})
    # Run uvicorn with JSON access logging. Host and port are hard-coded.
    # NOTE(review): 7860 is presumably the hosting platform's expected port, to be confirmed.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        access_log=True,
        # log_config=None: uvicorn skips its own logging config, leaving the
        # JSON handlers installed by setup_logging() in effect.
        log_config=None,
    )