File size: 2,779 Bytes
bf292d9
bbfbcdd
a3a1b05
ba71442
05be9a1
bf292d9
 
 
bbfbcdd
a3a1b05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf292d9
 
 
 
 
a3a1b05
 
0bf22fe
bf292d9
0bf22fe
bf292d9
 
05be9a1
a3a1b05
05be9a1
 
 
 
 
 
 
a3a1b05
bf292d9
 
 
 
 
 
 
a3a1b05
0bf22fe
 
 
bf292d9
 
 
0bf22fe
 
 
 
 
 
 
bf292d9
 
0bf22fe
bf292d9
 
0bf22fe
a3a1b05
 
 
 
 
0bf22fe
bf292d9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from typing import Callable, Dict, List, Optional
import aio_pika
from urllib.parse import urlsplit, unquote
from config import settings
import ssl  # ✅ correct import

# Callback mapping an exchange name to the exchange-type string that
# RabbitBase.ensure_exchange resolves against aio_pika.ExchangeType.
ExchangeResolver = Callable[[str], str]  # exchange name -> exchange type


def _parse_amqp_url(url: str) -> dict:
    """
    Convert an AMQP URL into kwargs for aio_pika.connect_robust,
    so we don't pass the raw URL (and leak the password in logs).
    """
    parts = urlsplit(url)
    return {
        "host": parts.hostname or "localhost",
        "port": parts.port or (5671 if parts.scheme == "amqps" else 5672),
        "login": parts.username or "guest",
        "password": parts.password or "guest",
        "virtualhost": unquote(parts.path[1:] or "/"),
        "ssl": parts.scheme == "amqps",
    }


class RabbitBase:
    """
    Shared RabbitMQ plumbing: a lazily opened robust connection and channel,
    plus a per-name cache of declared exchanges.

    Args:
        exchange_type_resolver: Maps an exchange name to its exchange type,
            given either as an ``aio_pika.ExchangeType`` member name
            (``"TOPIC"``) or as the AMQP type value (``"topic"``). Defaults to
            returning ``settings.RABBIT_EXCHANGE_TYPE`` for every exchange.
    """

    def __init__(self, exchange_type_resolver: Optional[ExchangeResolver] = None):
        self._conn: Optional[aio_pika.RobustConnection] = None
        self._chan: Optional[aio_pika.RobustChannel] = None
        # Exchange handles keyed by name; only valid for the current channel.
        self._exchanges: Dict[str, aio_pika.Exchange] = {}
        self._exchange_type_resolver = exchange_type_resolver or (
            lambda _: settings.RABBIT_EXCHANGE_TYPE
        )

    async def connect(self) -> None:
        """
        Open the connection and channel unless an open connection exists.

        Connection parameters come from ``settings.AMQP_URL`` (decomposed by
        ``_parse_amqp_url`` so the raw URL is never handed to aio_pika). The
        channel prefetch count comes from ``settings.RABBIT_PREFETCH``.
        """
        if self._conn is not None and not self._conn.is_closed:
            return

        # NOTE(review): concurrent first calls can each get past the check
        # above and open separate connections (the awaits below yield);
        # consider guarding with an asyncio.Lock if callers race at startup.

        # Cached exchange handles belong to the previous channel; drop them so
        # ensure_exchange re-declares on the channel opened below.
        self._exchanges.clear()

        conn_kwargs = _parse_amqp_url(str(settings.AMQP_URL))

        if conn_kwargs.get("ssl"):
            # Intended to disable TLS certificate verification.
            # NOTE(review): this permits man-in-the-middle attacks; prefer a
            # trusted CA bundle. Also confirm the installed aio_pika/aiormq
            # honours a "cert_reqs" key in ssl_options (some versions use
            # "no_verify_ssl" instead), otherwise this setting has no effect.
            conn_kwargs["ssl_options"] = {"cert_reqs": ssl.CERT_NONE}

        self._conn = await aio_pika.connect_robust(**conn_kwargs)
        self._chan = await self._conn.channel()
        await self._chan.set_qos(prefetch_count=settings.RABBIT_PREFETCH)

    @staticmethod
    def _resolve_exchange_type(ex_type: str) -> aio_pika.ExchangeType:
        """Map a member name ("TOPIC") or AMQP value ("topic") to ExchangeType."""
        # Member-name lookup first, preserving the original getattr behavior.
        member = getattr(aio_pika.ExchangeType, ex_type, None)
        if isinstance(member, aio_pika.ExchangeType):
            return member
        try:
            return aio_pika.ExchangeType(ex_type.lower())
        except ValueError:
            raise ValueError(
                f"Unknown RabbitMQ exchange type: {ex_type!r}"
            ) from None

    async def ensure_exchange(self, name: str) -> aio_pika.Exchange:
        """
        Return the durable exchange *name*, declaring it on first use.

        The exchange type comes from the resolver passed to ``__init__``. The
        declared exchange is cached until the next reconnect.

        Raises:
            ValueError: if the resolver yields an unknown exchange type.
        """
        await self.connect()
        if name in self._exchanges:
            return self._exchanges[name]
        ex_type = self._resolve_exchange_type(self._exchange_type_resolver(name))
        ex = await self._chan.declare_exchange(name, ex_type, durable=True)
        self._exchanges[name] = ex
        return ex

    async def declare_queue_bind(
        self,
        exchange: str,
        queue_name: str,
        routing_keys: Optional[List[str]],
        ttl_ms: Optional[int],
    ):
        """
        Declare *queue_name* and bind it to *exchange* once per routing key.

        The queue is declared with ``durable=True``, ``exclusive=False`` and
        ``auto_delete=True`` (RabbitMQ deletes auto-delete queues after their
        last consumer unsubscribes).

        Args:
            exchange: Exchange name; declared via ``ensure_exchange`` (which
                also opens the connection) if not yet cached.
            queue_name: Name of the queue to declare.
            routing_keys: Binding keys. ``None`` or an empty list binds once
                with the empty routing key ``""``.
            ttl_ms: Per-queue message TTL (``x-message-ttl``) in milliseconds.
                ``None`` or ``0`` leaves the TTL argument unset.

        Returns:
            The declared queue object returned by ``declare_queue``.
        """
        ex = await self.ensure_exchange(exchange)
        args: Dict[str, int] = {}
        # Truthiness test: 0 is deliberately treated like None (no TTL).
        if ttl_ms:
            args["x-message-ttl"] = ttl_ms
        q = await self._chan.declare_queue(
            queue_name,
            durable=True,
            exclusive=False,
            auto_delete=True,
            arguments=args,
        )
        for rk in routing_keys or [""]:
            await q.bind(ex, rk)
        return q