| """Global rate limiter for API requests.""" | |
| import asyncio | |
| import random | |
| import time | |
| from collections.abc import AsyncIterator, Callable | |
| from contextlib import asynccontextmanager | |
| from typing import Any, ClassVar, TypeVar | |
| import httpx | |
| import openai | |
| from loguru import logger | |
| from core.rate_limit import StrictSlidingWindowLimiter | |
| T = TypeVar("T") | |


class AdaptiveRateLimiter:
    """Adaptive rate limiter that backs off on 429s and recovers gradually.

    Starts at a high throughput and auto-adjusts based on upstream feedback.
    This gives maximum throughput in normal conditions while self-correcting
    when rate limits are hit.
    """

    _limiter_count: ClassVar[int] = 0

    def __init__(
        self,
        initial_rate: int = 100,
        min_rate: int = 10,
        window: float = 60.0,
        backoff_factor: float = 0.5,
        recovery_factor: float = 1.2,
    ) -> None:
        self._initial_rate = initial_rate
        self._current_rate = initial_rate
        self._min_rate = min_rate
        self._window = window
        self._backoff_factor = backoff_factor
        self._recovery_factor = recovery_factor
        self._limiter = StrictSlidingWindowLimiter(initial_rate, window)
        self._lock = asyncio.Lock()
        self._success_streak: int = 0
        self._instance_id = AdaptiveRateLimiter._limiter_count
        AdaptiveRateLimiter._limiter_count += 1

    async def acquire(self) -> None:
        """Wait for a slot under the current (possibly backed-off) rate."""
        await self._limiter.acquire()

    def record_429(self) -> None:
        """Called when a 429 is received; reduce rate immediately."""
        self._current_rate = max(
            self._min_rate, int(self._current_rate * self._backoff_factor)
        )
        self._limiter = StrictSlidingWindowLimiter(self._current_rate, self._window)
        self._success_streak = 0
        logger.warning(
            "ADAPTIVE_RATE: instance={} backed off to {} req/min (429 received)",
            self._instance_id,
            self._current_rate,
        )

    def record_success(self) -> None:
        """Called on success; gradually recover rate if below initial."""
        if self._current_rate >= self._initial_rate:
            self._success_streak = 0
            return
        self._success_streak += 1
        # Recover after 3 consecutive successes
        if self._success_streak >= 3:
            self._current_rate = min(
                self._initial_rate,
                int(self._current_rate * self._recovery_factor),
            )
            self._limiter = StrictSlidingWindowLimiter(self._current_rate, self._window)
            self._success_streak = 0
            logger.info(
                "ADAPTIVE_RATE: instance={} recovered to {} req/min",
                self._instance_id,
                self._current_rate,
            )
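
# A worked sketch of the backoff/recovery arithmetic above (illustrative only;
# the numbers assume the default backoff_factor=0.5 and recovery_factor=1.2):
#
#     limiter = AdaptiveRateLimiter(initial_rate=100, min_rate=10)
#     limiter.record_429()          # int(100 * 0.5) -> 50 req/window
#     limiter.record_429()          # int(50 * 0.5)  -> 25 req/window
#     for _ in range(3):
#         limiter.record_success()  # 3rd success: int(25 * 1.2) -> 30 req/window
#     # ...and so on, +20% per 3 successes, capped at initial_rate.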


class ModelHealthTracker:
    """Track per-model health based on recent failures."""

    _instance: ClassVar[ModelHealthTracker | None] = None

    def __init__(
        self,
        failure_ttl: float = 30.0,
        max_failures: int = 3,
        *,
        failure_ttl_nim: float = 15.0,
        max_failures_nim: int = 2,
        failure_ttl_zen: float = 60.0,
        max_failures_zen: int = 5,
    ) -> None:
        self._failure_ttl = failure_ttl
        self._max_failures = max_failures
        self._failure_ttl_nim = failure_ttl_nim
        self._max_failures_nim = max_failures_nim
        self._failure_ttl_zen = failure_ttl_zen
        self._max_failures_zen = max_failures_zen
        self._failures: dict[str, list[float]] = {}
        self._failure_ttls: dict[str, float] = {}
        self._max_failures_map: dict[str, int] = {}

    @classmethod
    def get_instance(cls) -> ModelHealthTracker:
        """Get or create the shared singleton instance."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def _params_for(self, model_ref: str) -> tuple[float, int]:
        """Return (failure_ttl, max_failures) for a model based on provider."""
        if model_ref in self._failure_ttls:
            return self._failure_ttls[model_ref], self._max_failures_map[model_ref]
        if model_ref.startswith("zen/"):
            return self._failure_ttl_zen, self._max_failures_zen
        if model_ref.startswith("nvidia_nim/"):
            return self._failure_ttl_nim, self._max_failures_nim
        return self._failure_ttl, self._max_failures

    def record_failure(self, model_ref: str) -> None:
        """Record a failure timestamp for a model."""
        now = time.monotonic()
        if model_ref not in self._failures:
            self._failures[model_ref] = []
        self._failures[model_ref].append(now)
        logger.debug("HEALTH: recorded failure for '{}'", model_ref)

    def is_healthy(self, model_ref: str) -> bool:
        """Check if model has had fewer than max_failures in the TTL window."""
        if model_ref not in self._failures:
            return True
        ttl, max_f = self._params_for(model_ref)
        cutoff = time.monotonic() - ttl
        recent = [t for t in self._failures[model_ref] if t > cutoff]
        self._failures[model_ref] = recent
        healthy = len(recent) < max_f
        if not healthy:
            logger.debug(
                "HEALTH: model '{}' is unhealthy ({} failures in {}s)",
                model_ref,
                len(recent),
                ttl,
            )
        return healthy

    def get_failure_count(self, model_ref: str) -> int:
        """Get number of recent failures for a model."""
        if model_ref not in self._failures:
            return 0
        ttl, _ = self._params_for(model_ref)
        cutoff = time.monotonic() - ttl
        return len([t for t in self._failures[model_ref] if t > cutoff])

    def clear_failures(self, model_ref: str) -> None:
        """Clear failure history for a model (on success)."""
        self._failures.pop(model_ref, None)
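
# Usage sketch (illustrative; the model name is hypothetical). Thresholds come
# from _params_for, so a "nvidia_nim/" model is considered unhealthy after 2
# failures within 15s, while the generic default is 3 failures within 30s:
#
#     tracker = ModelHealthTracker.get_instance()
#     tracker.record_failure("nvidia_nim/example-model")
#     tracker.record_failure("nvidia_nim/example-model")
#     tracker.is_healthy("nvidia_nim/example-model")   # False: 2 >= max_failures_nim
#     tracker.clear_failures("nvidia_nim/example-model")
#     tracker.is_healthy("nvidia_nim/example-model")   # True again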


class GlobalRateLimiter:
    """Global singleton rate limiter.

    Blocks all requests when a rate limit error is encountered (reactive) and
    throttles requests (proactive) using a strict rolling window. Optionally
    enforces a max_concurrency cap: at most N provider streams may be open
    simultaneously, independent of the sliding window.

    - Proactive limit: throttles requests to stay within API limits.
    - Reactive limit: pauses all requests when a 429 is hit.
    - Concurrency limit: caps simultaneously open streams.
    """

    _instance: ClassVar[GlobalRateLimiter | None] = None
    _scoped_instances: ClassVar[dict[str, GlobalRateLimiter]] = {}

    def __init__(
        self,
        rate_limit: int = 40,
        rate_window: float = 60.0,
        max_concurrency: int = 5,
        adaptive_rate: int | None = None,
        adaptive_min_rate: int = 10,
    ) -> None:
        # Prevent re-initialization on singleton reuse
        if hasattr(self, "_initialized"):
            return
        if rate_limit <= 0:
            raise ValueError("rate_limit must be > 0")
        if rate_window <= 0:
            raise ValueError("rate_window must be > 0")
        if max_concurrency <= 0:
            raise ValueError("max_concurrency must be > 0")
        self._rate_limit = rate_limit
        self._rate_window = float(rate_window)
        self._max_concurrency = max_concurrency
        self._adaptive_rate = adaptive_rate
        self._adaptive_min_rate = adaptive_min_rate
        if adaptive_rate is not None:
            self._proactive_limiter = AdaptiveRateLimiter(
                initial_rate=adaptive_rate,
                min_rate=adaptive_min_rate,
                window=float(rate_window),
            )
        else:
            self._proactive_limiter = StrictSlidingWindowLimiter(
                rate_limit, float(rate_window)
            )
        self._blocked_until: float = 0.0
        self._concurrency_sem = asyncio.Semaphore(max_concurrency)
        self._initialized = True
        limiter_type = (
            f"Adaptive({adaptive_rate}→{adaptive_min_rate})"
            if adaptive_rate is not None
            else f"Strict({rate_limit})"
        )
        logger.info(
            f"GlobalRateLimiter initialized {limiter_type} / {rate_window}s, "
            f"max_concurrency={max_concurrency}"
        )

    @classmethod
    def get_instance(
        cls,
        rate_limit: int | None = None,
        rate_window: float | None = None,
        max_concurrency: int = 5,
    ) -> GlobalRateLimiter:
        """Get or create the singleton instance.

        Args:
            rate_limit: Requests per window (only used on first creation)
            rate_window: Window in seconds (only used on first creation)
            max_concurrency: Max simultaneous open streams (only used on first creation)
        """
        if cls._instance is None:
            cls._instance = cls(
                rate_limit=rate_limit or 40,
                rate_window=rate_window or 60.0,
                max_concurrency=max_concurrency,
            )
        return cls._instance

    @classmethod
    def get_scoped_instance(
        cls,
        scope: str,
        *,
        rate_limit: int | None = None,
        rate_window: float | None = None,
        max_concurrency: int = 5,
        adaptive_rate: int | None = None,
        adaptive_min_rate: int = 10,
    ) -> GlobalRateLimiter:
        """Get or create a provider-scoped limiter instance.

        Zen gets an effectively unlimited rate (9999) since it has no rate
        limits. NIM gets an adaptive rate from the nim_rate_limit setting.
        """
        if not scope:
            raise ValueError("scope must be non-empty")
        desired_rate_limit = 9999 if scope == "zen" else rate_limit or 40
        desired_rate_window = float(rate_window or 60.0)
        existing = cls._scoped_instances.get(scope)
        if existing and existing.matches_config(
            desired_rate_limit, desired_rate_window, max_concurrency
        ):
            return existing
        if existing:
            logger.info(
                "Rebuilding provider rate limiter for updated scope '{}'", scope
            )
        # Adaptive rate only for NIM (not Zen, which is unlimited)
        use_adaptive = adaptive_rate if scope == "nvidia_nim" else None
        cls._scoped_instances[scope] = cls(
            rate_limit=desired_rate_limit,
            rate_window=desired_rate_window,
            max_concurrency=max_concurrency,
            adaptive_rate=use_adaptive,
            adaptive_min_rate=adaptive_min_rate,
        )
        return cls._scoped_instances[scope]
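
    # Example (illustrative; parameter values are made up): each provider gets
    # an independent limiter, and Zen's window is effectively uncapped:
    #
    #     nim = GlobalRateLimiter.get_scoped_instance("nvidia_nim", adaptive_rate=100)
    #     zen = GlobalRateLimiter.get_scoped_instance("zen")
    #     assert nim is not zen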

    @classmethod
    def reset_instance(cls) -> None:
        """Reset singleton and scoped instances (for testing)."""
        cls._instance = None
        cls._scoped_instances = {}

    async def wait_if_blocked(self) -> bool:
        """Wait if currently rate limited, then throttle to meet quota.

        Returns:
            True if reactively blocked and waited, False otherwise.
        """
        # 1. Reactive check: wait if someone hit a 429
        waited_reactively = False
        now = time.monotonic()
        if now < self._blocked_until:
            wait_time = self._blocked_until - now
            logger.warning(
                f"Global provider rate limit active (reactive), waiting {wait_time:.1f}s..."
            )
            await asyncio.sleep(wait_time)
            waited_reactively = True
        # 2. Proactive check: strict rolling window (no bursts beyond N in last W seconds)
        await self._acquire_proactive_slot()
        return waited_reactively

    async def _acquire_proactive_slot(self) -> None:
        """Acquire a proactive slot enforcing a strict rolling window.

        Guarantees at most the limiter's current rate of acquisitions in any
        interval of `self._rate_window` seconds (`self._rate_limit` for the
        strict limiter; the adaptive limiter may run lower after 429s).
        """
        await self._proactive_limiter.acquire()

    def set_blocked(self, seconds: float = 60) -> None:
        """Set global block for the specified duration (reactive).

        Args:
            seconds: How long to block (default 60s)
        """
        self._blocked_until = time.monotonic() + seconds
        logger.warning(f"Global provider rate limit set for {seconds:.1f}s (reactive)")

    def is_blocked(self) -> bool:
        """Check if currently reactively blocked."""
        return time.monotonic() < self._blocked_until

    def matches_config(
        self, rate_limit: int, rate_window: float, max_concurrency: int
    ) -> bool:
        """Return whether this limiter matches the requested runtime config."""
        return (
            self._rate_limit == rate_limit
            and self._rate_window == float(rate_window)
            and self._max_concurrency == max_concurrency
        )

    def remaining_wait(self) -> float:
        """Get remaining reactive wait time in seconds."""
        return max(0.0, self._blocked_until - time.monotonic())

    def record_failure(self, model_ref: str | None = None) -> None:
        """Record a failure for rate limit tracking.

        Args:
            model_ref: Optional model identifier for health tracking.
        """
        # Record in the shared health tracker if a model is provided
        if model_ref:
            health = ModelHealthTracker.get_instance()
            health.record_failure(model_ref)

    def is_healthy(self, model_ref: str | None = None) -> bool:
        """Check if provider/model is healthy based on failure history.

        Args:
            model_ref: Optional model identifier for health tracking.

        Returns:
            True if no recent failures or model_ref is None.
        """
        if model_ref is None:
            return True
        health = ModelHealthTracker.get_instance()
        return health.is_healthy(model_ref)

    @asynccontextmanager
    async def concurrency_slot(self) -> AsyncIterator[None]:
        """Async context manager that holds one concurrency slot for a stream.

        Blocks until a slot is available (controlled by max_concurrency).
        """
        await self._concurrency_sem.acquire()
        try:
            yield
        finally:
            self._concurrency_sem.release()
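
    # Usage sketch (illustrative):
    #
    #     async with limiter.concurrency_slot():
    #         ...  # stream from the provider while holding one of N slots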

    async def execute_with_retry(
        self,
        fn: Callable[..., Any],
        *args: Any,
        max_retries: int = 3,
        base_delay: float = 0.3,
        max_delay: float = 20.0,
        jitter: float = 0.1,
        **kwargs: Any,
    ) -> Any:
        """Execute an async callable with rate limiting and retry on 429.

        Waits for the proactive limiter before each attempt. On a 429, applies
        exponential backoff with jitter and notifies the adaptive rate limiter,
        which recovers more quickly than fixed delays would.

        Args:
            fn: Async callable to execute.
            args: Positional arguments forwarded to fn.
            max_retries: Maximum number of retry attempts after the first failure.
            base_delay: Base delay in seconds for exponential backoff.
            max_delay: Maximum delay cap in seconds.
            jitter: Maximum random jitter in seconds added to each delay.
            kwargs: Keyword arguments forwarded to fn.

        Returns:
            The result of the callable.

        Raises:
            The last exception if all retries are exhausted.
        """
        last_exc: Exception | None = None
        for attempt in range(1 + max_retries):
            await self.wait_if_blocked()
            try:
                result = await fn(*args, **kwargs)
                # Notify adaptive limiter of success (triggers gradual recovery)
                self._record_success_for_adaptive()
                return result
            except openai.RateLimitError as e:
                last_exc = e
                self._record_429_for_adaptive()
                if attempt >= max_retries:
                    logger.warning(
                        f"Rate limit retry exhausted after {max_retries} retries"
                    )
                    break
                delay = min(base_delay * (2**attempt), max_delay)
                delay += random.uniform(0, jitter)
                logger.warning(
                    f"Rate limited (429), attempt {attempt + 1}/{max_retries + 1}. "
                    f"Retrying in {delay:.1f}s..."
                )
                self.set_blocked(delay)
                await asyncio.sleep(delay)
            except httpx.HTTPStatusError as e:
                # Only 429s are retried; other HTTP errors propagate
                if e.response.status_code != 429:
                    raise
                last_exc = e
                self._record_429_for_adaptive()
                if attempt >= max_retries:
                    logger.warning(
                        f"HTTP 429 retry exhausted after {max_retries} retries"
                    )
                    break
                delay = min(base_delay * (2**attempt), max_delay)
                delay += random.uniform(0, jitter)
                logger.warning(
                    f"HTTP 429 from upstream, attempt {attempt + 1}/{max_retries + 1}. "
                    f"Retrying in {delay:.1f}s..."
                )
                self.set_blocked(delay)
                await asyncio.sleep(delay)
        assert last_exc is not None
        raise last_exc

    def _record_429_for_adaptive(self) -> None:
        """Notify adaptive limiter of a 429; triggers rate backoff."""
        if isinstance(self._proactive_limiter, AdaptiveRateLimiter):
            self._proactive_limiter.record_429()

    def _record_success_for_adaptive(self) -> None:
        """Notify adaptive limiter of success; triggers gradual rate recovery."""
        if isinstance(self._proactive_limiter, AdaptiveRateLimiter):
            self._proactive_limiter.record_success()
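

# Minimal smoke-test sketch (illustrative; not part of the module's API). It
# assumes the imports above resolve, in particular that
# core.rate_limit.StrictSlidingWindowLimiter is importable; fake_call is a
# hypothetical stand-in for a provider request that would raise
# openai.RateLimitError or httpx.HTTPStatusError on a 429.
if __name__ == "__main__":

    async def _demo() -> None:
        limiter = GlobalRateLimiter.get_instance(rate_limit=5, rate_window=10.0)

        async def fake_call(prompt: str) -> str:
            return f"echo: {prompt}"

        result = await limiter.execute_with_retry(fake_call, "hello", max_retries=2)
        logger.info("demo result: {}", result)

    asyncio.run(_demo())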