File size: 2,326 Bytes
f55f92e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from __future__ import annotations

import asyncio
import time
from typing import Awaitable, Callable


class RequestRateLimiter:
    """Async limiter that spaces out requests globally and per domain.

    A call to :meth:`acquire` returns only once both of these hold:

    * at least ``global_interval_seconds`` have elapsed since the previous
      permit for *any* domain, and
    * at least ``per_domain_interval_seconds`` have elapsed since the previous
      permit for the *same* (normalized) domain.

    Negative intervals are clamped to ``0.0``; an interval of ``0`` disables
    that constraint. ``clock`` and ``sleep`` default to :func:`time.monotonic`
    and :func:`asyncio.sleep` and may be injected (e.g. a fake clock in tests).

    NOTE: a lock and a last-permit timestamp are kept for every distinct
    domain ever passed to :meth:`acquire`; nothing in this class evicts them.
    """

    def __init__(
        self,
        global_interval_seconds: float,
        per_domain_interval_seconds: float,
        *,
        clock: Callable[[], float] | None = None,
        sleep: Callable[[float], Awaitable[None]] | None = None,
    ) -> None:
        self.global_interval_seconds = max(0.0, float(global_interval_seconds))
        self.per_domain_interval_seconds = max(0.0, float(per_domain_interval_seconds))
        self._clock = clock or time.monotonic
        self._sleep = sleep or asyncio.sleep

        # Serializes global-slot claims; held while sleeping so waiters queue.
        self._global_lock = asyncio.Lock()
        self._global_last: float | None = None

        # Guards creation of per-domain locks.
        self._domain_guard = asyncio.Lock()
        self._domain_locks: dict[str, asyncio.Lock] = {}
        # Clock value of the last permit granted per normalized domain.
        self._domain_last: dict[str, float] = {}

    async def acquire(self, domain: str) -> None:
        """Wait until a request to *domain* may be issued, then record it.

        The domain is lower-cased and stripped of leading/trailing dots. An
        empty normalized domain (or a per-domain interval of 0) is subject
        only to the global interval.

        The per-domain wait happens first and the global slot is claimed
        last, so the instant this coroutine returns satisfies *both*
        intervals. Claiming the global slot first and then sleeping for the
        domain interval would let the release time drift past the recorded
        global timestamp and grant permits closer together than the global
        interval.
        """
        normalized = domain.lower().strip(".")
        if not normalized or self.per_domain_interval_seconds <= 0:
            await self._acquire_global()
            return

        lock = await self._get_domain_lock(normalized)
        # Holding the domain lock across the global wait serializes callers
        # for the same domain (they must be spaced anyway). Locks are always
        # taken in domain -> global order, so this cannot deadlock.
        async with lock:
            await self._wait_for_spacing(
                self._domain_last.get(normalized),
                self.per_domain_interval_seconds,
            )
            await self._acquire_global()
            # Stamp with the actual release time (after any global wait).
            self._domain_last[normalized] = self._clock()

    async def _acquire_global(self) -> None:
        """Wait for the global interval, then record a global permit."""
        if self.global_interval_seconds <= 0:
            return

        async with self._global_lock:
            await self._wait_for_spacing(
                self._global_last, self.global_interval_seconds
            )
            self._global_last = self._clock()

    async def _wait_for_spacing(self, last: float | None, interval: float) -> None:
        """Sleep until at least *interval* seconds have passed since *last*.

        Returns immediately when *last* is ``None`` (no previous permit).
        """
        if last is None:
            return
        remaining = interval - (self._clock() - last)
        if remaining > 0:
            await self._sleep(remaining)

    async def _get_domain_lock(self, domain: str) -> asyncio.Lock:
        """Return the lock for *domain*, creating it on first use."""
        async with self._domain_guard:
            lock = self._domain_locks.get(domain)
            if lock is None:
                lock = asyncio.Lock()
                self._domain_locks[domain] = lock
            return lock