|
|
import aio_pika |
|
|
from typing import Optional |
|
|
from api.chat.chat_api import ChatAPI |
|
|
from .base import ChatQueueBase |
|
|
|
|
|
class RabbitMQQueue(ChatQueueBase):
    """Rotating pool of API keys stored as messages in a RabbitMQ queue.

    Each message body is one UTF-8 encoded API key. ``get`` takes a key off
    the head of the queue and immediately publishes it back to the tail, so
    successive calls cycle through the stored keys.
    """

    def __init__(self, url: str):
        """Store connection settings; no network I/O happens until ``connect``.

        Args:
            url: AMQP URL passed to ``aio_pika.connect_robust``.
        """
        self.url = url
        # Populated together by ``connect`` once setup fully succeeds.
        self.connection = None
        self.channel = None
        self.queue = None
        self.queue_name = "chat_queue"

    async def connect(self):
        """Open the connection, channel and queue declaration if not yet done.

        The attributes are assigned only after every step succeeds. A failure
        therefore never leaves a half-initialised object (a connection with no
        channel) that later calls would skip over. On failure the freshly
        opened connection is closed and the original exception re-raised.

        NOTE(review): two coroutines calling this concurrently before the first
        finishes can both open a connection; guard with an ``asyncio.Lock`` if
        callers may race on first use.
        """
        if self.connection is not None:
            return

        connection = await aio_pika.connect_robust(self.url)
        try:
            channel = await connection.channel()
            queue = await channel.declare_queue(self.queue_name)
        except BaseException:
            # Do not leak the socket when channel/queue setup fails.
            await connection.close()
            raise

        self.connection = connection
        self.channel = channel
        self.queue = queue

    async def add(self, api_key: str) -> None:
        """Publish ``api_key`` (UTF-8 encoded) to the queue via the default exchange."""
        await self.connect()

        message = aio_pika.Message(body=api_key.encode())

        await self.channel.default_exchange.publish(
            message,
            routing_key=self.queue_name,
        )

    async def get(self) -> Optional[ChatAPI]:
        """Take the next API key, re-enqueue it, and wrap it in a ``ChatAPI``.

        Returns:
            A ``ChatAPI`` built from the key at the head of the queue, or
            ``None`` when the queue is currently empty.

        The message is fetched with manual acknowledgement. It is acked only
        after the key has been published back to the queue. If decoding,
        ``ChatAPI`` construction, or the re-publish raises, the message is
        rejected with requeue, so the key stays in the rotation and the
        exception propagates.
        """
        await self.connect()

        # fail=False: return None on an empty queue instead of raising QueueEmpty.
        message = await self.queue.get(no_ack=False, fail=False)
        if message is None:
            return None

        async with message.process(requeue=True):
            api_key = message.body.decode()
            chat = ChatAPI(api_key=api_key)
            # Put the key back at the tail so keys are handed out in rotation.
            await self.add(api_key)

        return chat

    async def length(self) -> int:
        """Return the number of ready messages currently in the queue.

        The queue is re-declared because ``declaration_result.message_count``
        is a snapshot taken at declaration time. Messages delivered but not yet
        acknowledged are not included in the broker's count.
        """
        await self.connect()

        queue = await self.channel.declare_queue(self.queue_name)

        return queue.declaration_result.message_count