zyon-traders-backend / services /websocket_manager.py
zyonpackers's picture
Update services/websocket_manager.py
38e3ad0 verified
"""
WebSocket Manager for Zyon Traders Backend
Handles real-time data distribution to connected clients
"""
import asyncio
import json
import logging
from typing import Dict, List, Set
from fastapi import WebSocket
from datetime import datetime
import random
import os
from config.huggingface import get_huggingface_config, get_optimized_websocket_settings
# Module-level logger used by WebSocketManager for connection and broadcast events.
logger = logging.getLogger(__name__)
class WebSocketManager:
    """Manages WebSocket connections and real-time data distribution.

    Every accepted connection is kept in ``active_connections`` and mapped in
    ``subscriptions`` to the set of symbols it follows. A background task
    started by ``start_background_tasks`` publishes simulated market data to
    the subscribers of each symbol.
    """

    # Reference prices for the simulated feed. Insertion order is the order in
    # which symbols are updated on every tick.
    _SIMULATED_BASE_PRICES: Dict[str, int] = {
        "NIFTY": 22000,
        "SENSEX": 72000,
        "BANKNIFTY": 47000,
        "RELIANCE": 2500,
        "TCS": 3800,
        "INFY": 1700,
        "HDFC": 1600,
    }

    def __init__(self):
        self.active_connections: List[WebSocket] = []
        self.subscriptions: Dict[WebSocket, Set[str]] = {}
        # Handle of the simulation task; None until start_background_tasks() runs.
        self.background_task: "asyncio.Task | None" = None
        self.running = False
        # Load HuggingFace-optimized configuration
        self.config = get_huggingface_config()
        # NOTE(review): not read anywhere in this class; presumably used elsewhere — verify.
        self.ws_settings = get_optimized_websocket_settings()
        self.max_connections = self.config['max_connections']
        logger.info(f"WebSocket manager initialized for {'HuggingFace' if self.config['is_huggingface'] else 'standard'} environment")
        logger.info(f"Max connections: {self.max_connections}")

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as a naive ISO-8601 string (no offset).

        ``datetime.utcnow()`` is deprecated since Python 3.12. It is kept here
        so the emitted timestamp format (no "+00:00" suffix) stays unchanged.
        """
        return datetime.utcnow().isoformat()

    @staticmethod
    def _normalize_symbols(symbols) -> List[str]:
        """Return *symbols* as a list, wrapping a lone string.

        ``set.update("NIFTY")`` would otherwise add the characters
        "N", "I", "F", ... instead of the symbol "NIFTY".
        """
        if isinstance(symbols, str):
            return [symbols]
        return list(symbols)

    async def connect(self, websocket: WebSocket) -> bool:
        """Accept a new WebSocket connection unless the connection cap is reached.

        Returns:
            True if the connection was accepted and registered. False if it was
            rejected because ``max_connections`` is reached, or if the
            handshake failed.
        """
        if len(self.active_connections) >= self.max_connections:
            logger.warning(f"Connection limit reached ({self.max_connections})")
            try:
                # 1013 is the "Try Again Later" close code.
                await websocket.close(code=1013, reason="Server overloaded")
            except Exception as e:
                # Rejection is best-effort; the client may already be gone.
                logger.debug(f"Failed to reject WebSocket connection cleanly: {e}")
            return False
        try:
            await websocket.accept()
        except Exception as e:
            logger.error(f"Failed to accept WebSocket connection: {e}")
            return False
        self.active_connections.append(websocket)
        self.subscriptions[websocket] = set()
        logger.info(f"New WebSocket connection. Total: {len(self.active_connections)}/{self.max_connections}")
        return True

    def disconnect(self, websocket: WebSocket) -> None:
        """Forget a connection and its subscriptions. Safe to call repeatedly."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
        self.subscriptions.pop(websocket, None)
        logger.info(f"WebSocket disconnected. Total: {len(self.active_connections)}")

    async def subscribe_to_symbols(self, websocket: WebSocket, symbols: List[str]):
        """Subscribe a registered connection to *symbols* and acknowledge it.

        *symbols* may be a list of symbol names or a single symbol string.
        Unregistered connections are ignored. Errors from sending the
        acknowledgement propagate to the caller.
        """
        if websocket not in self.subscriptions:
            return
        symbols = self._normalize_symbols(symbols)
        self.subscriptions[websocket].update(symbols)
        await self._send_message_safely(websocket, json.dumps({
            "type": "subscription_success",
            "symbols": symbols,
            "timestamp": self._utc_timestamp()
        }))
        logger.info(f"WebSocket subscribed to symbols: {symbols}")

    async def unsubscribe_from_symbols(self, websocket: WebSocket, symbols: List[str]):
        """Remove *symbols* from a registered connection's subscriptions and acknowledge it.

        *symbols* may be a list of symbol names or a single symbol string.
        Unregistered connections are ignored.
        """
        if websocket not in self.subscriptions:
            return
        symbols = self._normalize_symbols(symbols)
        self.subscriptions[websocket].difference_update(symbols)
        await self._send_message_safely(websocket, json.dumps({
            "type": "unsubscription_success",
            "symbols": symbols,
            "timestamp": self._utc_timestamp()
        }))
        logger.info(f"WebSocket unsubscribed from symbols: {symbols}")

    async def _deliver(self, websockets: List[WebSocket], message: str) -> None:
        """Send *message* to every socket in *websockets* concurrently.

        Sockets whose send fails or times out are disconnected. The failure has
        already been logged by ``_send_message_safely``.
        """
        if not websockets:
            return
        results = await asyncio.gather(
            *(self._send_message_safely(ws, message) for ws in websockets),
            return_exceptions=True,
        )
        for websocket, result in zip(websockets, results):
            if isinstance(result, BaseException):
                self.disconnect(websocket)

    async def broadcast_to_symbol_subscribers(self, symbol: str, data: dict):
        """Broadcast *data* to all clients subscribed to *symbol*."""
        # Snapshot the targets: any await can let other coroutines call
        # connect()/disconnect() and mutate active_connections.
        targets = [
            ws for ws in self.active_connections
            if symbol in self.subscriptions.get(ws, ())
        ]
        if not targets:
            return
        message = json.dumps({
            "type": "market_data",
            "symbol": symbol,
            "data": data,
            "timestamp": self._utc_timestamp()
        })
        await self._deliver(targets, message)

    async def broadcast_to_all(self, data: dict):
        """Broadcast *data* to all connected clients."""
        message = json.dumps({
            "type": "broadcast",
            "data": data,
            "timestamp": self._utc_timestamp()
        })
        # Snapshot for the same reason as in broadcast_to_symbol_subscribers.
        await self._deliver(list(self.active_connections), message)

    async def simulate_market_data(self, update_interval: float = 2.0, error_backoff: float = 5.0):
        """Publish random price ticks for the simulated symbols while ``running`` is set.

        Each tick draws a move of -2%..+2% around a fixed reference price.
        Prices do not accumulate between ticks. The tick is broadcast to that
        symbol's subscribers.

        Args:
            update_interval: Seconds to wait between full rounds of updates.
            error_backoff: Seconds to wait after an unexpected error before retrying.
        """
        while self.running:
            try:
                for symbol, base_price in self._SIMULATED_BASE_PRICES.items():
                    change_percent = random.uniform(-2, 2)
                    price = base_price * (1 + change_percent / 100)
                    data = {
                        "price": round(price, 2),
                        "change": round(base_price * change_percent / 100, 2),
                        "change_percent": round(change_percent, 2),
                        "volume": random.randint(100000, 10000000),
                        "last_updated": self._utc_timestamp()
                    }
                    await self.broadcast_to_symbol_subscribers(symbol, data)
                await asyncio.sleep(update_interval)
            except Exception as e:
                logger.exception(f"Error in market data simulation: {e}")
                await asyncio.sleep(error_backoff)

    async def start_background_tasks(self):
        """Start the market-data simulation task unless one is already alive.

        Calling this again while the task is still running is a no-op. This
        avoids a second publisher and an orphaned task handle.
        """
        if self.background_task is not None and not self.background_task.done():
            logger.warning("WebSocket background tasks already running")
            return
        self.running = True
        self.background_task = asyncio.create_task(self.simulate_market_data())
        logger.info("WebSocket background tasks started")

    async def cleanup(self):
        """Stop the simulation task and close every connection (best-effort)."""
        self.running = False
        task, self.background_task = self.background_task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        # Iterate a snapshot: each await can let other coroutines (e.g. an
        # endpoint reacting to the close) call disconnect() and shrink the list.
        for websocket in list(self.active_connections):
            try:
                await websocket.close()
            except Exception as e:
                # Closing is best-effort; the peer may already be gone.
                logger.debug(f"Error closing WebSocket during cleanup: {e}")
        self.active_connections.clear()
        self.subscriptions.clear()
        logger.info("WebSocket manager cleaned up")

    async def _send_message_safely(self, websocket: WebSocket, message: str):
        """Send *message*, bounded by the configured ``websocket_timeout``.

        Raises:
            asyncio.TimeoutError: the send did not finish within the timeout.
            Exception: any error from ``send_text``. It is logged, then re-raised
                so callers can drop the connection.
        """
        # A missing setting means "no timeout"; wait_for treats None as unbounded.
        timeout = self.config.get('websocket_timeout')
        try:
            await asyncio.wait_for(websocket.send_text(message), timeout=timeout)
        except asyncio.TimeoutError:
            logger.warning("WebSocket send timeout")
            raise
        except Exception as e:
            logger.error(f"Error sending data to WebSocket: {e}")
            raise