Spaces:
Build error
Build error
| from typing import Any | |
| import asyncio | |
| import redis.asyncio as aioredis | |
| import json | |
| from fastapi import WebSocket | |
class RedisPubSubManager:
    """
    Thin wrapper around a redis.asyncio client for channel-per-room pub/sub.

    Call connect() before publishing, subscribing or unsubscribing.
    """

    def __init__(self, host='localhost', port=6379):
        """
        Initializes the RedisPubSubManager without opening a connection.

        Args:
            host (str): Redis server host.
            port (int): Redis server port.
        """
        self.redis_host = host
        self.redis_port = port
        self.redis_connection = None
        self.pubsub = None
        # room_id -> pubsub object that subscribed it. connect() replaces
        # self.pubsub, so unsubscribe() must use the object that actually
        # holds the subscription rather than whatever self.pubsub is now.
        self._room_pubsubs: dict = {}

    async def _get_redis_connection(self) -> "aioredis.Redis":
        """
        Builds a Redis client for the configured host and port.

        Returns:
            aioredis.Redis: Redis client object.
        """
        return aioredis.Redis(host=self.redis_host,
                              port=self.redis_port,
                              auto_close_connection_pool=False)

    async def connect(self) -> None:
        """
        Creates a Redis client and a pubsub client on top of it.

        Each call creates a new client and replaces self.redis_connection and
        self.pubsub; objects created by earlier calls are not closed.
        """
        self.redis_connection = await self._get_redis_connection()
        self.pubsub = self.redis_connection.pubsub()

    async def _publish(self, room_id: str, message: str) -> None:
        """
        Publishes a message to a room's Redis channel.

        Args:
            room_id (str): Channel or room ID.
            message (str): Message to be published.

        Raises:
            RuntimeError: If connect() has not been called.
        """
        if self.redis_connection is None:
            raise RuntimeError("RedisPubSubManager is not connected; call connect() first")
        await self.redis_connection.publish(room_id, message)

    async def subscribe(self, room_id: str) -> "aioredis.client.PubSub":
        """
        Subscribes the current pubsub client to a room's Redis channel.

        Args:
            room_id (str): Channel or room ID to subscribe to.

        Returns:
            aioredis.client.PubSub: The pubsub object now subscribed to the channel.

        Raises:
            RuntimeError: If connect() has not been called.
        """
        pubsub = self.pubsub
        if pubsub is None:
            raise RuntimeError("RedisPubSubManager is not connected; call connect() first")
        await pubsub.subscribe(room_id)
        self._room_pubsubs[room_id] = pubsub
        return pubsub

    async def unsubscribe(self, room_id: str) -> None:
        """
        Unsubscribes from a room's Redis channel.

        The UNSUBSCRIBE is sent on the pubsub object that subscribed the room,
        falling back to the current pubsub for rooms subscribed elsewhere.

        Args:
            room_id (str): Channel or room ID to unsubscribe from.

        Raises:
            RuntimeError: If there is no pubsub object to unsubscribe with.
        """
        pubsub = self._room_pubsubs.pop(room_id, self.pubsub)
        if pubsub is None:
            raise RuntimeError("RedisPubSubManager is not connected; call connect() first")
        await pubsub.unsubscribe(room_id)
class WebSocketManager:
    """Tracks WebSocket connections per room and fans Redis pub/sub messages out to them."""

    def __init__(self):
        """
        Initializes the WebSocketManager.

        Attributes:
            rooms (dict): Maps a room ID to the list of WebSocket connections in that room.
            qa (dict): Per-room state keyed by room ID; a room's entry is discarded
                whenever a user leaves it (presumably populated elsewhere — verify).
            pubsub_client (RedisPubSubManager): An instance of the RedisPubSubManager class for pub-sub functionality.
        """
        self.rooms: dict = {}
        self.qa: dict = {}
        self.pubsub_client = RedisPubSubManager()
        # Strong reference to the single background reader task. The event loop
        # only keeps weak references to tasks, so an unreferenced task may be
        # garbage-collected mid-execution.
        self._reader_task = None

    async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
        """
        Accepts a user's WebSocket connection and adds it to a room.

        The first user in a room triggers a Redis subscription to that room's
        channel. One Redis connection and one pubsub client are shared by all
        rooms, and a single background task reads messages for all of them.

        Args:
            room_id (str): Room ID or channel name.
            websocket (WebSocket): WebSocket connection object.
        """
        await websocket.accept()
        if room_id in self.rooms:
            self.rooms[room_id].append(websocket)
            return
        self.rooms[room_id] = [websocket]
        # Connect only once: connect() builds a new client on every call, and
        # reconnecting per room would leak clients and orphan the pubsub objects
        # that hold earlier rooms' subscriptions.
        if self.pubsub_client.pubsub is None:
            await self.pubsub_client.connect()
        pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
        # One reader serves every channel on the shared pubsub. Restart it if it
        # has stopped (e.g. after a Redis connection error).
        if self._reader_task is None or self._reader_task.done():
            self._reader_task = asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))

    async def broadcast_to_room(self, room_id: str, message: str) -> None:
        """
        Broadcasts a message to all connected WebSockets in a room.

        The message is published to Redis; delivery to the sockets happens in
        the pubsub reader task.

        Args:
            room_id (str): Room ID or channel name.
            message (str): Message to be broadcasted.
        """
        await self.pubsub_client._publish(room_id, message)

    async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
        """
        Removes a user's WebSocket connection from a room.

        Unknown rooms or sockets are ignored, so repeated removal is harmless.
        When the last user leaves, the room is deleted and its Redis channel is
        unsubscribed.

        Args:
            room_id (str): Room ID or channel name.
            websocket (WebSocket): WebSocket connection object.
        """
        sockets = self.rooms.get(room_id)
        if sockets is None:
            return
        if websocket in sockets:
            sockets.remove(websocket)
        self.qa.pop(room_id, None)
        if not sockets:
            del self.rooms[room_id]
            await self.pubsub_client.unsubscribe(room_id)

    async def _pubsub_data_reader(self, pubsub_subscriber):
        """
        Reads messages from Redis PubSub and forwards them to the room's sockets.

        Runs until cancelled or until get_message raises.

        Args:
            pubsub_subscriber (aioredis.client.PubSub): PubSub object subscribed to the room channels.
        """
        import logging

        logger = logging.getLogger(__name__)
        while True:
            # A positive timeout makes get_message wait for data instead of
            # returning None immediately, which would busy-spin this loop.
            message = await pubsub_subscriber.get_message(ignore_subscribe_messages=True, timeout=1.0)
            if message is None:
                continue
            room_id = message['channel'].decode('utf-8')
            data = message['data'].decode('utf-8')
            # Snapshot the list: remove_user_from_room may mutate it while we
            # await send_text. The room may also already be gone.
            for socket in list(self.rooms.get(room_id, ())):
                try:
                    await socket.send_text(data)
                except Exception:
                    # One broken socket must not kill the reader for every room.
                    logger.exception("Failed to send pub/sub message to a WebSocket in room %s", room_id)