| """ |
| WebSocket Service |
| Handles real-time data updates to connected clients |
| """ |
| import asyncio |
| import json |
| import logging |
| from typing import Dict, Set, Any, List, Optional |
| from datetime import datetime |
| from fastapi import WebSocket, WebSocketDisconnect |
| from collections import defaultdict |
|
|
# Module logger, named after this module so its records can be filtered by name.
logger = logging.getLogger(name=__name__)
|
|
|
|
class ConnectionManager:
    """Manages WebSocket connections, per-API subscriptions, and broadcasts.

    All bookkeeping is keyed by a caller-supplied ``client_id``:

    * ``active_connections``: client_id -> live ``WebSocket``.
    * ``subscriptions``: api_id -> set of client_ids subscribed to that API
      (reverse index consulted when broadcasting an API update).
    * ``client_subscriptions``: client_id -> set of api_ids it follows.
      The wildcard ``'*'`` (``WILDCARD``) means "every API"; it is stored
      only here, never as a key of ``subscriptions``.
    * ``connection_metadata``: client_id -> metadata dict given to ``connect``.
    """

    # Pseudo API id meaning "subscribed to all API updates".
    WILDCARD = '*'

    def __init__(self):
        # client_id -> live WebSocket
        self.active_connections: Dict[str, WebSocket] = {}
        # api_id -> client_ids subscribed to that API
        self.subscriptions: Dict[str, Set[str]] = defaultdict(set)
        # client_id -> api_ids the client subscribed to (may include '*')
        self.client_subscriptions: Dict[str, Set[str]] = defaultdict(set)
        # client_id -> metadata passed to connect()
        self.connection_metadata: Dict[str, Dict[str, Any]] = {}

    async def connect(self, websocket: WebSocket, client_id: str, metadata: Optional[Dict] = None):
        """
        Accept a WebSocket and register it under ``client_id``.

        If ``client_id`` is already registered, the stored socket is replaced
        (the previous one is not closed) and that id's subscriptions are kept.

        Args:
            websocket: WebSocket connection to accept
            client_id: Unique client identifier
            metadata: Optional metadata about the connection (stored as-is;
                a falsy value is stored as an empty dict)
        """
        await websocket.accept()
        self.active_connections[client_id] = websocket
        self.connection_metadata[client_id] = metadata or {}
        logger.info("Client %s connected. Total connections: %d",
                    client_id, len(self.active_connections))

    def disconnect(self, client_id: str):
        """
        Forget a client: its socket, subscriptions, and metadata.

        Safe to call more than once or for an unknown id. The socket itself
        is not closed here.

        Args:
            client_id: Client identifier
        """
        self.active_connections.pop(client_id, None)

        # Remove the client from each per-API reverse index entry first.
        # Iterate over a copy because unsubscribe() mutates the set.
        for api_id in list(self.client_subscriptions.get(client_id, ())):
            self.unsubscribe(client_id, api_id)

        self.client_subscriptions.pop(client_id, None)
        self.connection_metadata.pop(client_id, None)

        logger.info("Client %s disconnected. Total connections: %d",
                    client_id, len(self.active_connections))

    def subscribe(self, client_id: str, api_id: str):
        """
        Subscribe a client to updates for one API.

        Args:
            client_id: Client identifier
            api_id: API identifier to subscribe to
        """
        self.subscriptions[api_id].add(client_id)
        self.client_subscriptions[client_id].add(api_id)
        logger.debug("Client %s subscribed to %s", client_id, api_id)

    def unsubscribe(self, client_id: str, api_id: str):
        """
        Unsubscribe a client from one API (no-op if it was not subscribed).

        The API's entry in ``subscriptions`` is removed once it has no
        subscribers left.

        Args:
            client_id: Client identifier
            api_id: API identifier to unsubscribe from (``'*'`` removes the
                all-updates subscription)
        """
        subscribers = self.subscriptions.get(api_id)
        if subscribers is not None:
            subscribers.discard(client_id)
            if not subscribers:
                del self.subscriptions[api_id]

        if client_id in self.client_subscriptions:
            self.client_subscriptions[client_id].discard(api_id)

        logger.debug("Client %s unsubscribed from %s", client_id, api_id)

    def subscribe_all(self, client_id: str):
        """
        Subscribe a client to updates for every API via the ``'*'`` wildcard.

        Args:
            client_id: Client identifier
        """
        self.client_subscriptions[client_id].add(self.WILDCARD)
        logger.debug("Client %s subscribed to all updates", client_id)

    def _wildcard_subscribers(self) -> Set[str]:
        """Return the ids of clients subscribed to every API."""
        return {cid for cid, subs in self.client_subscriptions.items()
                if self.WILDCARD in subs}

    async def send_personal_message(self, message: Dict[str, Any], client_id: str):
        """
        Send a JSON message to one client.

        Does nothing if ``client_id`` is not connected. If sending fails, the
        error is logged and the client is disconnected.

        Args:
            message: JSON-serialisable message data
            client_id: Target client identifier
        """
        websocket = self.active_connections.get(client_id)
        if websocket is None:
            return
        try:
            await websocket.send_json(message)
        except Exception as e:
            # Any send failure is treated as a dead connection.
            logger.error("Error sending message to %s: %s", client_id, e)
            self.disconnect(client_id)

    async def broadcast(self, message: Dict[str, Any], api_id: Optional[str] = None):
        """
        Send a JSON message to many clients.

        Args:
            message: JSON-serialisable message data
            api_id: When given (not None), only clients subscribed to this API
                plus wildcard ('*') subscribers receive the message. When
                None, every active connection receives it.

        Sends happen one after another. Clients whose send fails are
        disconnected after the loop so the bookkeeping is not mutated
        mid-iteration.
        """
        # Compare against None explicitly: an empty-string api_id must not
        # turn an API-specific update into a broadcast to everyone.
        if api_id is not None:
            target_clients = set(self.subscriptions.get(api_id, ()))
            target_clients.update(self._wildcard_subscribers())
        else:
            target_clients = set(self.active_connections)

        failed_clients: List[str] = []
        for client_id in target_clients:
            websocket = self.active_connections.get(client_id)
            if websocket is None:
                # Subscribed but no longer (or not yet) connected.
                continue
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.error("Error broadcasting to %s: %s", client_id, e)
                failed_clients.append(client_id)

        for client_id in failed_clients:
            self.disconnect(client_id)

    async def broadcast_api_update(self, api_id: str, data: Dict[str, Any], metadata: Optional[Dict] = None):
        """
        Broadcast an ``api_update`` message to the API's subscribers.

        Args:
            api_id: API identifier
            data: Updated data
            metadata: Optional metadata about the update
        """
        # NOTE(review): datetime.now() is naive local time; consider a
        # timezone-aware timestamp if clients parse it across time zones.
        message = {
            'type': 'api_update',
            'api_id': api_id,
            'data': data,
            'metadata': metadata or {},
            'timestamp': datetime.now().isoformat()
        }
        await self.broadcast(message, api_id)

    async def broadcast_status_update(self, status: Dict[str, Any]):
        """
        Broadcast a ``status_update`` message to every connected client.

        Args:
            status: Status data
        """
        message = {
            'type': 'status_update',
            'status': status,
            'timestamp': datetime.now().isoformat()
        }
        await self.broadcast(message)

    async def broadcast_schedule_update(self, schedule_info: Dict[str, Any]):
        """
        Broadcast a ``schedule_update`` message to every connected client.

        Args:
            schedule_info: Schedule information
        """
        message = {
            'type': 'schedule_update',
            'schedule': schedule_info,
            'timestamp': datetime.now().isoformat()
        }
        await self.broadcast(message)

    def get_connection_stats(self) -> Dict[str, Any]:
        """
        Get connection statistics.

        Returns:
            Dict with ``total_connections``, ``total_subscriptions`` (per-API
            subscriptions only; '*' wildcard subscriptions are not counted),
            ``apis_with_subscribers``, and per-client ``subscriptions`` /
            ``metadata`` for every active connection.
        """
        return {
            'total_connections': len(self.active_connections),
            'total_subscriptions': sum(len(subs) for subs in self.subscriptions.values()),
            'apis_with_subscribers': len(self.subscriptions),
            'clients': {
                client_id: {
                    'subscriptions': list(self.client_subscriptions.get(client_id, set())),
                    'metadata': self.connection_metadata.get(client_id, {})
                }
                for client_id in self.active_connections
            }
        }
|
|
|
|
class WebSocketService:
    """WebSocket service for real-time updates.

    Owns a ``ConnectionManager`` and turns decoded client messages into
    subscription changes, data/schedule queries, and scheduler commands.
    The scheduler and persistence services are optional; a request that
    needs a missing service (or lacks a required ``api_id``) gets no reply.
    """

    # Client message 'type' -> name of the coroutine method that handles it.
    # An explicit table (instead of getattr on the client-supplied string)
    # keeps clients from reaching arbitrary methods.
    _MESSAGE_HANDLERS: Dict[str, str] = {
        'subscribe': '_handle_subscribe',
        'subscribe_all': '_handle_subscribe_all',
        'unsubscribe': '_handle_unsubscribe',
        'get_data': '_handle_get_data',
        'get_all_data': '_handle_get_all_data',
        'get_schedule': '_handle_get_schedule',
        'update_schedule': '_handle_update_schedule',
        'force_update': '_handle_force_update',
        'ping': '_handle_ping',
    }

    def __init__(self, scheduler_service=None, persistence_service=None):
        self.connection_manager = ConnectionManager()
        self.scheduler_service = scheduler_service
        self.persistence_service = persistence_service
        # Not read anywhere in this class.
        self.running = False

        if self.scheduler_service:
            self._register_scheduler_callbacks()

    def _register_scheduler_callbacks(self):
        """Register callbacks with the scheduler service (currently a no-op)."""
        # TODO: not implemented yet.
        pass

    async def handle_client_message(self, websocket: WebSocket, client_id: str, message: Dict[str, Any]):
        """
        Dispatch one decoded client message to its handler.

        Args:
            websocket: WebSocket connection (unused; replies are routed by
                ``client_id`` through the connection manager)
            client_id: Client identifier
            message: Decoded message; its ``'type'`` key selects the handler

        Unknown types are logged and ignored. Any exception raised while
        handling is logged with its traceback and reported back to the
        client as ``{'type': 'error', 'message': str(exc)}``.
        """
        try:
            message_type = message.get('type')
            # Guard the lookup so unhashable 'type' values fall through to the
            # "unknown type" path instead of raising TypeError.
            handler_name = (
                self._MESSAGE_HANDLERS.get(message_type)
                if isinstance(message_type, str) else None
            )
            if handler_name is None:
                logger.warning("Unknown message type from %s: %s", client_id, message_type)
                return
            await getattr(self, handler_name)(client_id, message)
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Error handling client message: %s", e)
            # NOTE(review): str(e) is echoed to the client and may expose
            # internal details; consider a generic message.
            await self._reply(client_id, {
                'type': 'error',
                'message': str(e)
            })

    async def _reply(self, client_id: str, payload: Dict[str, Any]):
        """Send ``payload`` to ``client_id`` via the connection manager."""
        await self.connection_manager.send_personal_message(payload, client_id)

    async def _handle_subscribe(self, client_id: str, message: Dict[str, Any]):
        """Subscribe to ``message['api_id']`` and acknowledge; no-op without an id."""
        api_id = message.get('api_id')
        if api_id:
            self.connection_manager.subscribe(client_id, api_id)
            await self._reply(client_id, {
                'type': 'subscribed',
                'api_id': api_id,
                'status': 'success'
            })

    async def _handle_subscribe_all(self, client_id: str, message: Dict[str, Any]):
        """Subscribe to every API and acknowledge with ``api_id='*'``."""
        self.connection_manager.subscribe_all(client_id)
        await self._reply(client_id, {
            'type': 'subscribed',
            'api_id': '*',
            'status': 'success'
        })

    async def _handle_unsubscribe(self, client_id: str, message: Dict[str, Any]):
        """Unsubscribe from ``message['api_id']`` and acknowledge; no-op without an id."""
        api_id = message.get('api_id')
        if api_id:
            self.connection_manager.unsubscribe(client_id, api_id)
            await self._reply(client_id, {
                'type': 'unsubscribed',
                'api_id': api_id,
                'status': 'success'
            })

    async def _handle_get_data(self, client_id: str, message: Dict[str, Any]):
        """Reply with cached data for one API (needs an api_id and a persistence service)."""
        api_id = message.get('api_id')
        if api_id and self.persistence_service:
            data = self.persistence_service.get_cached_data(api_id)
            await self._reply(client_id, {
                'type': 'data_response',
                'api_id': api_id,
                'data': data
            })

    async def _handle_get_all_data(self, client_id: str, message: Dict[str, Any]):
        """Reply with all cached data (needs a persistence service)."""
        if self.persistence_service:
            data = self.persistence_service.get_all_cached_data()
            await self._reply(client_id, {
                'type': 'data_response',
                'data': data
            })

    async def _handle_get_schedule(self, client_id: str, message: Dict[str, Any]):
        """Reply with every task status from the scheduler (needs a scheduler)."""
        if self.scheduler_service:
            schedules = self.scheduler_service.get_all_task_statuses()
            await self._reply(client_id, {
                'type': 'schedule_response',
                'schedules': schedules
            })

    async def _handle_update_schedule(self, client_id: str, message: Dict[str, Any]):
        """Forward ``interval``/``enabled`` (possibly None) to the scheduler and acknowledge."""
        api_id = message.get('api_id')
        interval = message.get('interval')
        enabled = message.get('enabled')
        if api_id and self.scheduler_service:
            self.scheduler_service.update_task_schedule(api_id, interval, enabled)
            await self._reply(client_id, {
                'type': 'schedule_updated',
                'api_id': api_id,
                'status': 'success'
            })

    async def _handle_force_update(self, client_id: str, message: Dict[str, Any]):
        """Trigger an immediate update for one API and report success/failure."""
        api_id = message.get('api_id')
        if api_id and self.scheduler_service:
            success = await self.scheduler_service.force_update(api_id)
            await self._reply(client_id, {
                'type': 'update_result',
                'api_id': api_id,
                'status': 'success' if success else 'failed'
            })

    async def _handle_ping(self, client_id: str, message: Dict[str, Any]):
        """Reply with a timestamped ``pong``."""
        await self._reply(client_id, {
            'type': 'pong',
            'timestamp': datetime.now().isoformat()
        })

    async def notify_data_update(self, api_id: str, data: Dict[str, Any], metadata: Optional[Dict] = None):
        """
        Notify subscribers of ``api_id`` (and wildcard subscribers) about new data.

        Args:
            api_id: API identifier
            data: Updated data
            metadata: Optional metadata
        """
        await self.connection_manager.broadcast_api_update(api_id, data, metadata)

    async def notify_status_update(self, status: Dict[str, Any]):
        """
        Notify every connected client about a status update.

        Args:
            status: Status information
        """
        await self.connection_manager.broadcast_status_update(status)

    async def notify_schedule_update(self, schedule_info: Dict[str, Any]):
        """
        Notify every connected client about a schedule update.

        Args:
            schedule_info: Schedule information
        """
        await self.connection_manager.broadcast_schedule_update(schedule_info)

    def get_stats(self) -> Dict[str, Any]:
        """Return the connection manager's connection statistics."""
        return self.connection_manager.get_connection_stats()
|
|
|
|
| |
# Module-level service instance, created at import time with no scheduler
# or persistence backend attached.
websocket_service = WebSocketService(scheduler_service=None, persistence_service=None)
|
|