|
|
"""Rate limiting utilities using the limits library.""" |
|
|
|
|
|
import asyncio |
|
|
from typing import ClassVar |
|
|
|
|
|
from limits import RateLimitItem, parse |
|
|
from limits.storage import MemoryStorage |
|
|
from limits.strategies import MovingWindowRateLimiter |
|
|
|
|
|
|
|
|
class RateLimiter: |
|
|
""" |
|
|
Async-compatible rate limiter using limits library. |
|
|
|
|
|
Uses moving window algorithm for smooth rate limiting. |
|
|
""" |
|
|
|
|
|
def __init__(self, rate: str) -> None: |
|
|
""" |
|
|
Initialize rate limiter. |
|
|
|
|
|
Args: |
|
|
rate: Rate string like "3/second" or "10/second" |
|
|
""" |
|
|
self.rate = rate |
|
|
self._storage = MemoryStorage() |
|
|
self._limiter = MovingWindowRateLimiter(self._storage) |
|
|
self._rate_limit: RateLimitItem = parse(rate) |
|
|
self._identity = "default" |
|
|
|
|
|
async def acquire(self, wait: bool = True) -> bool: |
|
|
""" |
|
|
Acquire permission to make a request. |
|
|
|
|
|
ASYNC-SAFE: Uses asyncio.sleep(), never time.sleep(). |
|
|
The polling pattern allows other coroutines to run while waiting. |
|
|
|
|
|
Args: |
|
|
wait: If True, wait until allowed. If False, return immediately. |
|
|
|
|
|
Returns: |
|
|
True if allowed, False if not (only when wait=False) |
|
|
""" |
|
|
while True: |
|
|
|
|
|
if self._limiter.hit(self._rate_limit, self._identity): |
|
|
return True |
|
|
|
|
|
if not wait: |
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
await asyncio.sleep(0.01) |
|
|
|
|
|
def reset(self) -> None: |
|
|
"""Reset the rate limiter (for testing).""" |
|
|
self._storage.reset() |
|
|
|
|
|
|
|
|
|
|
|
_pubmed_limiter: RateLimiter | None = None |
|
|
|
|
|
|
|
|
def get_pubmed_limiter(api_key: str | None = None) -> RateLimiter: |
|
|
""" |
|
|
Get the shared PubMed rate limiter. |
|
|
|
|
|
Rate depends on whether API key is provided: |
|
|
- Without key: 3 requests/second |
|
|
- With key: 10 requests/second |
|
|
|
|
|
Args: |
|
|
api_key: NCBI API key (optional) |
|
|
|
|
|
Returns: |
|
|
Shared RateLimiter instance |
|
|
""" |
|
|
global _pubmed_limiter |
|
|
|
|
|
if _pubmed_limiter is None: |
|
|
rate = "10/second" if api_key else "3/second" |
|
|
_pubmed_limiter = RateLimiter(rate) |
|
|
|
|
|
return _pubmed_limiter |
|
|
|
|
|
|
|
|
def reset_pubmed_limiter() -> None: |
|
|
"""Reset the PubMed limiter (for testing).""" |
|
|
global _pubmed_limiter |
|
|
_pubmed_limiter = None |
|
|
|
|
|
|
|
|
|
|
|
class RateLimiterFactory: |
|
|
"""Factory for creating/getting rate limiters for different APIs.""" |
|
|
|
|
|
_limiters: ClassVar[dict[str, RateLimiter]] = {} |
|
|
|
|
|
@classmethod |
|
|
def get(cls, api_name: str, rate: str) -> RateLimiter: |
|
|
""" |
|
|
Get or create a rate limiter for an API. |
|
|
|
|
|
Args: |
|
|
api_name: Unique identifier for the API |
|
|
rate: Rate limit string (e.g., "10/second") |
|
|
|
|
|
Returns: |
|
|
RateLimiter instance (shared for same api_name) |
|
|
""" |
|
|
if api_name not in cls._limiters: |
|
|
cls._limiters[api_name] = RateLimiter(rate) |
|
|
return cls._limiters[api_name] |
|
|
|
|
|
@classmethod |
|
|
def reset_all(cls) -> None: |
|
|
"""Reset all limiters (for testing).""" |
|
|
cls._limiters.clear() |
|
|
|