| """ |
| WebSocket Chat Handler |
| ====================== |
| Handles real-time chat via WebSocket with streaming responses. |
| """ |
|
|
| import json |
| import asyncio |
| import logging |
| from typing import Optional |
|
|
| from fastapi import APIRouter, WebSocket, WebSocketDisconnect |
| from eurus.config import CONFIG |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router holding this module's WebSocket endpoint(s); presumably included by
# the FastAPI app elsewhere — confirm at the app factory.
router = APIRouter()
|
|
|
|
class ConnectionManager:
    """Track the WebSocket connections that are currently open.

    Sockets are kept in ``active_connections`` in the order they were
    accepted. Every registration/unregistration is logged together with the
    resulting total.
    """

    def __init__(self):
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake for *websocket*, then start tracking it."""
        await websocket.accept()
        conns = self.active_connections
        conns.append(websocket)
        logger.info("WebSocket connected. Total: %d", len(conns))

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*; a socket that is not tracked is ignored.

        This does not close the socket itself. The new total is logged in
        either case.
        """
        conns = self.active_connections
        if websocket in conns:
            conns.remove(websocket)
        logger.info("WebSocket disconnected. Total: %d", len(conns))

    async def send_json(self, websocket: WebSocket, data: dict):
        """Best-effort send of *data* as a JSON frame.

        Any exception raised by the send is logged and swallowed, so callers
        never see a failure here.
        """
        try:
            await websocket.send_json(data)
        except Exception as exc:
            logger.error("Failed to send message: %s", exc)
|
|
|
|
# Single registry of open chat sockets for this module (one per process);
# the chat endpoint registers and unregisters connections through it.
manager: ConnectionManager = ConnectionManager()
|
|
|
|
@router.websocket("/ws/chat")
async def websocket_chat(websocket: WebSocket):
    """WebSocket endpoint for chat.

    Each connection gets a random ``connection_id`` that keys its agent
    session in ``web.agent_wrapper``. Incoming frames are JSON objects,
    dispatched on their ``"type"`` field:

    * ``"configure_keys"`` -- (re)create the session with the supplied API
      keys; reply ``{"type": "keys_configured", "ready": <bool>}``.
    * ``"set_provider"`` -- switch the session's model (creating a session
      first if needed); reply ``{"type": "provider_changed", "model": ...}``.
    * ``"get_models"`` -- reply ``{"type": "models_list", "models": ...,
      "current": ...}``.
    * anything else -- a chat frame whose ``"message"`` string is processed.
      Empty text is ignored. ``"/clear"`` clears the session history and
      replies ``{"type": "clear"}``. Any other text is passed to
      ``session.process_message`` with a callback that forwards streamed
      events. The reply sequence is ``"thinking"``, then streamed events,
      then ``"complete"``, or ``"error"`` if processing raised.

    Frames that are not valid JSON, or whose payload is not a JSON object,
    get an ``"error"`` reply instead of terminating the connection. However
    the loop ends (disconnect, unexpected error, or task cancellation), the
    socket is unregistered from ``manager`` and the session is closed.
    """
    import uuid

    # Imported lazily (presumably to avoid an import cycle or startup cost --
    # TODO confirm), but *before* accepting the socket. If this import fails
    # there is no half-registered connection to clean up, and the ``finally``
    # below can always rely on ``close_session`` being bound.
    from web.agent_wrapper import close_session, create_session, get_session

    connection_id = str(uuid.uuid4())
    short_id = connection_id[:8]

    await manager.connect(websocket)
    logger.info("New connection: %s", connection_id)

    async def stream_callback(event_type: str, content: str, **kwargs):
        # Forward one streamed agent event to the client. Extra keyword
        # arguments become additional fields of the JSON frame.
        msg = {"type": event_type, "content": content}
        msg.update(kwargs)
        await manager.send_json(websocket, msg)

    session = None
    try:
        while True:
            try:
                data = await websocket.receive_json()
            except json.JSONDecodeError:
                # One malformed frame should not tear down the whole connection.
                logger.warning("[%s] Ignoring frame that is not valid JSON", short_id)
                await manager.send_json(websocket, {
                    "type": "error",
                    "content": "Invalid message: expected a JSON object.",
                })
                continue

            if not isinstance(data, dict):
                await manager.send_json(websocket, {
                    "type": "error",
                    "content": "Invalid message: expected a JSON object.",
                })
                continue

            msg_type = data.get("type")

            if msg_type == "configure_keys":
                api_keys = {
                    "openai_api_key": data.get("openai_api_key", ""),
                    "arraylake_api_key": data.get("arraylake_api_key", ""),
                    "hf_token": data.get("hf_token", ""),
                }
                session = create_session(connection_id, api_keys=api_keys)
                await manager.send_json(websocket, {
                    "type": "keys_configured",
                    "ready": session.is_ready(),
                })
                continue

            if msg_type == "set_provider":
                # NOTE(review): hard-coded default; "get_models" falls back to
                # CONFIG.model_name instead -- consider unifying.
                model_id = data.get("model", "gpt-5.2")
                if session is None:
                    session = create_session(connection_id)
                session.set_provider(model_id)
                await manager.send_json(websocket, {
                    "type": "provider_changed",
                    "model": model_id,
                })
                continue

            if msg_type == "get_models":
                from web.agent_wrapper import AgentSession
                current = session.get_current_model() if session else CONFIG.model_name
                await manager.send_json(websocket, {
                    "type": "models_list",
                    "models": AgentSession.AVAILABLE_MODELS,
                    "current": current,
                })
                continue

            # Any other frame is a chat message; make sure a session exists.
            if session is None:
                session = create_session(connection_id)

            # Tolerate a missing, null or non-string "message" field.
            raw_message = data.get("message")
            message = raw_message.strip() if isinstance(raw_message, str) else ""
            if not message:
                continue

            logger.info("[%s] Received: %s...", short_id, message[:100])

            if message == "/clear":
                session = get_session(connection_id)
                if session:
                    session.clear_messages()
                await manager.send_json(websocket, {"type": "clear"})
                continue

            await manager.send_json(websocket, {"type": "thinking"})

            try:
                # Re-fetch: the wrapper may no longer hold a session for this
                # connection (handled below as an expired session).
                session = get_session(connection_id)
                if not session:
                    logger.warning("Session lost for %s, requesting keys...", short_id)
                    await manager.send_json(websocket, {
                        "type": "request_keys",
                        "reason": "Session expired, please reconnect.",
                    })
                    continue

                await session.process_message(message, stream_callback)
                await manager.send_json(websocket, {"type": "complete"})
            except Exception as e:
                # Per-message failure: report it and keep the connection open.
                logger.exception("Error: %s", e)
                await manager.send_json(websocket, {
                    "type": "error",
                    "content": str(e),
                })

    except WebSocketDisconnect:
        logger.info("Connection %s disconnected", short_id)
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
    finally:
        # Runs on disconnect, unexpected errors and task cancellation alike.
        manager.disconnect(websocket)
        close_session(connection_id)
|
|