| """Rate limiting utilities using the limits library.""" |
|
|
| import asyncio |
| from typing import ClassVar |
|
|
| from limits import RateLimitItem, parse |
| from limits.storage import MemoryStorage |
| from limits.strategies import MovingWindowRateLimiter |
|
|
|
|
| class RateLimiter: |
| """ |
| Async-compatible rate limiter using limits library. |
| |
| Uses moving window algorithm for smooth rate limiting. |
| """ |
|
|
| def __init__(self, rate: str) -> None: |
| """ |
| Initialize rate limiter. |
| |
| Args: |
| rate: Rate string like "3/second" or "10/second" |
| """ |
| self.rate = rate |
| self._storage = MemoryStorage() |
| self._limiter = MovingWindowRateLimiter(self._storage) |
| self._rate_limit: RateLimitItem = parse(rate) |
| self._identity = "default" |
|
|
| async def acquire(self, wait: bool = True) -> bool: |
| """ |
| Acquire permission to make a request. |
| |
| ASYNC-SAFE: Uses asyncio.sleep(), never time.sleep(). |
| The polling pattern allows other coroutines to run while waiting. |
| |
| Args: |
| wait: If True, wait until allowed. If False, return immediately. |
| |
| Returns: |
| True if allowed, False if not (only when wait=False) |
| """ |
| while True: |
| |
| if self._limiter.hit(self._rate_limit, self._identity): |
| return True |
|
|
| if not wait: |
| return False |
|
|
| |
| |
| |
| |
| await asyncio.sleep(0.01) |
|
|
| def reset(self) -> None: |
| """Reset the rate limiter (for testing).""" |
| self._storage.reset() |
|
|
|
|
| |
# Module-level shared PubMed limiter. Starts as None; get_pubmed_limiter()
# creates it on first use and reset_pubmed_limiter() sets it back to None.
_pubmed_limiter: RateLimiter | None = None
|
|
|
|
def get_pubmed_limiter(api_key: str | None = None) -> RateLimiter:
    """
    Get the shared PubMed rate limiter.

    Rate depends on whether API key is provided:
    - Without key: 3 requests/second
    - With key: 10 requests/second

    The shared instance is reused while the required rate stays the same.
    If a call needs a different rate than the cached limiter has (an API key
    appears or disappears), a new limiter with the correct rate replaces it.

    Args:
        api_key: NCBI API key (optional)

    Returns:
        Shared RateLimiter instance
    """
    global _pubmed_limiter

    rate = "10/second" if api_key else "3/second"

    # Previously the rate was fixed by whichever call came first, so a
    # keyless caller could receive a 10/second limiter (or a keyed caller
    # stay throttled at 3/second). Rebuild when the required rate differs.
    if _pubmed_limiter is None or _pubmed_limiter.rate != rate:
        _pubmed_limiter = RateLimiter(rate)

    return _pubmed_limiter
|
|
|
|
def reset_pubmed_limiter() -> None:
    """
    Forget the cached PubMed limiter (test helper).

    After this call the module-level slot is empty again, so the next
    get_pubmed_limiter() call builds a brand-new limiter.
    """
    global _pubmed_limiter
    _pubmed_limiter = None
|
|
|
|
| |
class RateLimiterFactory:
    """Registry that hands out one shared RateLimiter per API name."""

    # Class-level registry shared by every caller, keyed by API name.
    _limiters: ClassVar[dict[str, RateLimiter]] = {}

    @classmethod
    def get(cls, api_name: str, rate: str) -> RateLimiter:
        """
        Return the limiter registered under ``api_name``, creating it if absent.

        Args:
            api_name: Unique identifier for the API
            rate: Rate limit string (e.g., "10/second"). Only used when no
                limiter exists yet for ``api_name``; an existing limiter is
                returned unchanged even if ``rate`` differs.

        Returns:
            RateLimiter instance (shared for same api_name)
        """
        registry = cls._limiters
        if api_name in registry:
            return registry[api_name]

        fresh = RateLimiter(rate)
        registry[api_name] = fresh
        return fresh

    @classmethod
    def reset_all(cls) -> None:
        """Empty the registry in place so every API gets a new limiter (testing aid)."""
        cls._limiters.clear()
|
|