Spaces:
Sleeping
Sleeping
| import asyncio | |
| import json | |
| import logging | |
| from fastapi import HTTPException, WebSocket, status | |
| from typing import Dict | |
class InRequest:
    """Pending-reply table for a single socket.

    ``responses`` maps each outstanding request ID to the future that will
    eventually hold that request's reply payload. It starts out empty and is
    private to each instance.
    """

    def __init__(self):
        table: Dict[str, asyncio.Future] = dict()
        self.responses = table
class ConnectionManager:
    """Registry of WebSocket connections keyed by socket ID.

    Besides plain send/receive helpers, it correlates requests with replies:
    ``listen()`` parks one future per ``(socket_id, request_id)`` and
    ``notify()`` resolves it when a JSON message carrying that ``request_id``
    arrives on the socket.
    """

    def __init__(self):
        # Earliest still-connected socket ID, or None when nothing is connected.
        self.available = None
        self.active_connections: Dict[str, WebSocket] = {}  # Maps socket ID to WebSocket connection
        self.in_request: Dict[str, InRequest] = {}  # Store pending response futures

    async def connect(self, socket_id: str, websocket: WebSocket):
        """Accept the handshake and register *websocket* under *socket_id*.

        If no socket is currently ``available``, this one becomes it.
        Returns *socket_id* unchanged.
        """
        await websocket.accept()
        self.active_connections[socket_id] = websocket
        if self.available is None:
            self.available = socket_id
        return socket_id

    def disconnect(self, socket_id: str):
        """Forget *socket_id* and fail any replies still awaited on it.

        Unknown IDs are ignored. If the socket was ``available``, the oldest
        remaining connection (dict insertion order) takes its place, or None
        if there is none.
        """
        self.active_connections.pop(socket_id, None)
        if self.available == socket_id:
            self.available = next(iter(self.active_connections), None)
        # Cancelling makes pending listen() calls raise the 502 HTTPException
        # instead of waiting forever for a reply that can no longer arrive.
        pending = self.in_request.pop(socket_id, None)
        if pending is not None:
            for future in list(pending.responses.values()):
                if not future.done():
                    future.cancel()

    async def broadcast(self, message: str):
        """Send *message* as a text frame to every connected socket, in turn."""
        # Snapshot: connect()/disconnect() may mutate the dict while we await.
        for connection in list(self.active_connections.values()):
            await connection.send_text(message)

    def _get_connection(self, socket_id: str, detail: str = "") -> WebSocket:
        """Return the live WebSocket for *socket_id* or raise HTTPException(502).

        *detail* overrides the default "Socket ID ... not connected" message.
        """
        websocket = self.active_connections.get(socket_id)
        # Identity test, not truthiness: a registered connection must never be
        # mistaken for a missing one (Starlette connections are Mapping-like).
        if websocket is None:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail=detail or f"Socket ID {socket_id} not connected")
        return websocket

    async def receive_text(self, socket_id: str):
        """Await and return the next text frame from *socket_id*.

        Raises HTTPException(502) if the socket is not connected.
        """
        websocket = self._get_connection(socket_id)
        return await websocket.receive_text()

    async def send_text(self, socket_id: str, message: str):
        """Send *message* as a text frame to *socket_id*.

        Raises HTTPException(502) if the socket is not connected.
        """
        websocket = self._get_connection(socket_id, "WebSocket connection not found.")
        await websocket.send_text(message)

    async def send_bytes(self, socket_id: str, binary_data: bytes):
        """Send *binary_data* as a binary frame to *socket_id*.

        Raises HTTPException(502) if the socket is not connected.
        """
        websocket = self._get_connection(socket_id)
        await websocket.send_bytes(binary_data)

    async def listen(self, socket_id: str, request_id: str) -> str:
        """Wait for the reply to *request_id* on *socket_id* and return its payload.

        Several requests may be outstanding on one socket at once; each has its
        own future. Reusing a *request_id* that is still pending on the same
        socket replaces the earlier future. That earlier waiter then only
        returns through cancellation (e.g. disconnect()).

        Returns the ``payload`` field of the matching JSON message as handed to
        notify(). It is whatever JSON value the field held, not necessarily a
        str, and None if absent.

        Raises HTTPException(502) if the socket is not connected, or if the
        wait is cancelled (including by disconnect()).

        NOTE(review): the future is registered only when listen() runs, so a
        reply that reaches notify() before then is dropped. Callers that send
        first and listen second rely on the reply not arriving in between.
        """
        if socket_id not in self.active_connections:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail=f"Socket ID {socket_id} not connected or canceled")
        future = asyncio.get_running_loop().create_future()
        # Share the socket's table so concurrent requests don't evict each other.
        pending = self.in_request.setdefault(socket_id, InRequest())
        pending.responses[request_id] = future
        try:
            return await future  # Await the future until it's set with a response
        except asyncio.CancelledError as err:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail=f"Socket ID {socket_id} not connected or canceled") from err
        finally:
            self._forget_request(socket_id, pending, request_id, future)

    def _forget_request(self, socket_id: str, pending, request_id: str, future) -> None:
        """Unregister *future*, and drop the socket's table once it is empty.

        Identity checks avoid removing a newer future registered under the same
        request_id, or a table that disconnect() already removed (and a later
        listen() may have recreated).
        """
        if pending.responses.get(request_id) is future:
            del pending.responses[request_id]
        if not pending.responses and self.in_request.get(socket_id) is pending:
            del self.in_request[socket_id]

    async def notify(self, socket_id: str, message: str):
        """Resolve the listen() future that matches a reply received on *socket_id*.

        The message is ignored in these cases:
        - nobody is listening on the socket;
        - the message carries no usable ``request_id``;
        - that ID is not, or is no longer, awaited.

        Only the matched future is touched. Other requests pending on the same
        socket stay registered.
        """
        logging.debug(message)
        pending = self.in_request.get(socket_id)
        if pending is None:
            return
        request_id, payload = self.extract_message(message)
        # JSON arrays/objects are unhashable, so they can never be a request ID.
        if request_id is None or isinstance(request_id, (list, dict)):
            return
        future = pending.responses.get(request_id)
        if future is None:
            logging.warning("notify: no pending request %r on socket %s", request_id, socket_id)
            return
        # The waiter may already be gone (cancelled task or disconnect()).
        if not future.done():
            future.set_result(payload)

    def extract_message(self, message: str):
        """Split a JSON reply into ``(request_id, payload)``.

        Either element is None when its field is absent. Malformed JSON, or a
        JSON value that is not an object, yields ``(None, None)`` and a warning.
        A literal ``null`` yields ``(None, None)`` silently.
        """
        try:
            obj = json.loads(message)
        except (TypeError, ValueError, RecursionError) as err:
            logging.warning("extract_message error: %s", err)
            return None, None
        if obj is None:
            return None, None
        if not isinstance(obj, dict):
            logging.warning(
                "extract_message error: expected a JSON object, got %s",
                type(obj).__name__)
            return None, None
        return obj.get('request_id'), obj.get('payload')