Spaces:
Runtime error
Runtime error
| """ | |
| WebSocket Manager for Zyon Traders Backend | |
| Handles real-time data distribution to connected clients | |
| """ | |
| import asyncio | |
| import json | |
| import logging | |
| from typing import Dict, List, Set | |
| from fastapi import WebSocket | |
| from datetime import datetime | |
| import random | |
| import os | |
| from config.huggingface import get_huggingface_config, get_optimized_websocket_settings | |
# Module-level logger shared by every method of WebSocketManager.
logger = logging.getLogger(__name__)


class WebSocketManager:
    """Manages WebSocket connections and real-time data distribution.

    Tracks the accepted client sockets, the set of symbols each client is
    subscribed to, and an optional background task that pushes simulated
    market data to subscribers every two seconds.
    """

    # Reference prices for the market-data simulator. Only read, never
    # mutated; dict order is the order symbols are broadcast on each tick.
    _BASE_PRICES: Dict[str, int] = {
        "NIFTY": 22000,
        "SENSEX": 72000,
        "BANKNIFTY": 47000,
        "RELIANCE": 2500,
        "TCS": 3800,
        "INFY": 1700,
        "HDFC": 1600,
    }

    def __init__(self):
        self.active_connections: List[WebSocket] = []
        self.subscriptions: Dict[WebSocket, Set[str]] = {}
        # None until start_background_tasks() schedules the simulator.
        self.background_task: "asyncio.Task | None" = None
        self.running = False

        # Load HuggingFace-optimized configuration
        self.config = get_huggingface_config()
        self.ws_settings = get_optimized_websocket_settings()
        self.max_connections = self.config['max_connections']

        logger.info(f"WebSocket manager initialized for {'HuggingFace' if self.config['is_huggingface'] else 'standard'} environment")
        logger.info(f"Max connections: {self.max_connections}")

    def _utc_timestamp(self) -> str:
        """Return the current UTC time as a naive ISO-8601 string (no offset)."""
        return datetime.utcnow().isoformat()

    async def connect(self, websocket: WebSocket):
        """Accept a new WebSocket connection, enforcing the connection limit.

        Returns True when the socket was accepted and registered, False when
        it was rejected (limit reached) or the handshake failed.
        """
        if len(self.active_connections) >= self.max_connections:
            logger.warning(f"Connection limit reached ({self.max_connections})")
            try:
                # Sent before accept(); 1013 is the "Try Again Later" close code.
                await websocket.close(code=1013, reason="Server overloaded")
            except Exception as e:
                # The client may already be gone; rejection is best-effort.
                logger.debug(f"Failed to reject WebSocket connection: {e}")
            return False

        try:
            await websocket.accept()
        except Exception as e:
            logger.error(f"Failed to accept WebSocket connection: {e}")
            return False

        self.active_connections.append(websocket)
        self.subscriptions[websocket] = set()
        logger.info(f"New WebSocket connection. Total: {len(self.active_connections)}/{self.max_connections}")
        return True

    def disconnect(self, websocket: WebSocket):
        """Forget a WebSocket and its subscriptions; safe to call more than once."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
        self.subscriptions.pop(websocket, None)
        logger.info(f"WebSocket disconnected. Total: {len(self.active_connections)}")

    async def subscribe_to_symbols(self, websocket: WebSocket, symbols: List[str]):
        """Add *symbols* to the client's subscriptions and acknowledge them.

        Sockets that are not registered (never connected or already
        disconnected) are ignored.
        """
        subscribed = self.subscriptions.get(websocket)
        if subscribed is None:
            return
        subscribed.update(symbols)
        await websocket.send_text(json.dumps({
            "type": "subscription_success",
            "symbols": symbols,
            "timestamp": self._utc_timestamp(),
        }))
        logger.info(f"WebSocket subscribed to symbols: {symbols}")

    async def unsubscribe_from_symbols(self, websocket: WebSocket, symbols: List[str]):
        """Remove *symbols* from the client's subscriptions and acknowledge it.

        Sockets that are not registered are ignored.
        """
        subscribed = self.subscriptions.get(websocket)
        if subscribed is None:
            return
        subscribed.difference_update(symbols)
        await websocket.send_text(json.dumps({
            "type": "unsubscription_success",
            "symbols": symbols,
            "timestamp": self._utc_timestamp(),
        }))
        logger.info(f"WebSocket unsubscribed from symbols: {symbols}")

    async def _deliver(self, message: str, symbol: "str | None" = None) -> None:
        """Send *message* to every client, or only to subscribers of *symbol*.

        Iterates over a snapshot of the connection list: disconnect() may be
        called by other tasks while a send is being awaited, and mutating the
        list under a live iterator would silently skip clients. Clients whose
        send fails or times out are disconnected afterwards.
        """
        failed = []
        for websocket in list(self.active_connections):
            subscribed = self.subscriptions.get(websocket)
            if subscribed is None:
                # Disconnected while earlier clients were being served.
                continue
            if symbol is not None and symbol not in subscribed:
                continue
            try:
                await self._send_message_safely(websocket, message)
            except Exception:
                # _send_message_safely has already logged the failure.
                failed.append(websocket)

        for websocket in failed:
            self.disconnect(websocket)

    async def broadcast_to_symbol_subscribers(self, symbol: str, data: dict):
        """Broadcast a ``market_data`` message to all subscribers of *symbol*."""
        message = json.dumps({
            "type": "market_data",
            "symbol": symbol,
            "data": data,
            "timestamp": self._utc_timestamp(),
        })
        await self._deliver(message, symbol)

    async def broadcast_to_all(self, data: dict):
        """Broadcast a ``broadcast`` message to every connected client."""
        message = json.dumps({
            "type": "broadcast",
            "data": data,
            "timestamp": self._utc_timestamp(),
        })
        await self._deliver(message)

    def _simulated_quote(self, base_price: float) -> dict:
        """Build one random quote within +/-2% of *base_price*."""
        change_percent = random.uniform(-2, 2)
        price = base_price * (1 + change_percent / 100)
        return {
            "price": round(price, 2),
            "change": round(base_price * change_percent / 100, 2),
            "change_percent": round(change_percent, 2),
            "volume": random.randint(100000, 10000000),
            "last_updated": self._utc_timestamp(),
        }

    async def simulate_market_data(self):
        """Push simulated quotes to subscribers every 2 s while ``running``."""
        while self.running:
            try:
                for symbol, base_price in self._BASE_PRICES.items():
                    await self.broadcast_to_symbol_subscribers(symbol, self._simulated_quote(base_price))
                await asyncio.sleep(2)  # Update every 2 seconds
            except Exception as e:
                logger.exception(f"Error in market data simulation: {e}")
                # Back off before retrying so a persistent failure doesn't spin.
                await asyncio.sleep(5)

    async def start_background_tasks(self):
        """Start the market-data simulator unless one is already running."""
        self.running = True
        if self.background_task is not None and not self.background_task.done():
            # A second create_task would orphan the running simulator.
            logger.info("WebSocket background tasks already running")
            return
        self.background_task = asyncio.create_task(self.simulate_market_data())
        logger.info("WebSocket background tasks started")

    async def cleanup(self):
        """Stop the simulator, close every connection and reset all state."""
        self.running = False
        task, self.background_task = self.background_task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        # Snapshot: closing a socket may make its handler call disconnect().
        for websocket in list(self.active_connections):
            try:
                await websocket.close()
            except Exception as e:
                # Best-effort shutdown; the peer may already be gone.
                logger.debug(f"Error closing WebSocket during cleanup: {e}")

        self.active_connections.clear()
        self.subscriptions.clear()
        logger.info("WebSocket manager cleaned up")

    async def _send_message_safely(self, websocket: WebSocket, message: str):
        """Send *message*, bounded by the configured ``websocket_timeout``.

        A missing or None ``websocket_timeout`` means no timeout. Timeouts
        and send errors are logged and then re-raised to the caller.
        """
        timeout = self.config.get('websocket_timeout')
        try:
            await asyncio.wait_for(websocket.send_text(message), timeout=timeout)
        except asyncio.TimeoutError:
            logger.warning("WebSocket send timeout")
            raise
        except Exception as e:
            logger.error(f"Error sending data to WebSocket: {e}")
            raise