claude-code-proxy / providers /rate_limit.py
Yash030's picture
NIM speed optimization — adaptive rate limiting and increased throughput
aa9c0b0
"""Global rate limiter for API requests."""
import asyncio
import random
import time
from collections import deque
from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager
from typing import Any, ClassVar, TypeVar

import httpx
import openai
from loguru import logger

from core.rate_limit import StrictSlidingWindowLimiter
# Generic type variable. Nothing in the visible part of this module references
# it — possibly used by code not shown here or left over (TODO confirm before removing).
T = TypeVar("T")
class AdaptiveRateLimiter:
    """Adaptive sliding-window rate limiter that backs off on 429s and recovers gradually.

    Starts at ``initial_rate`` requests per ``window`` seconds and adjusts from
    upstream feedback:

    - Every ``record_429`` multiplies the rate by ``backoff_factor``, floored
      at ``min_rate``.
    - Every ``recovery_streak`` consecutive ``record_success`` calls multiply
      it by ``recovery_factor``, capped at ``initial_rate`` and growing by at
      least one request per step.

    The acquisition history is kept across rate changes. Rebuilding a fresh
    window on every change would forget requests already sent in the current
    window, so a backoff would immediately admit a new burst of
    ``current_rate`` requests — the opposite of backing off.
    """

    _limiter_count: ClassVar[int] = 0

    def __init__(
        self,
        initial_rate: int = 100,
        min_rate: int = 10,
        window: float = 60.0,
        backoff_factor: float = 0.5,
        recovery_factor: float = 1.2,
        recovery_streak: int = 3,
    ) -> None:
        """Create a limiter.

        Args:
            initial_rate: Starting (and maximum) requests per window; must be > 0.
            min_rate: Floor for backoff; clamped into ``[1, initial_rate]``.
            window: Rolling window length in seconds; must be > 0.
            backoff_factor: Multiplier applied to the rate on each 429.
            recovery_factor: Multiplier applied when a success streak completes.
            recovery_streak: Consecutive successes required per recovery step.

        Raises:
            ValueError: If ``initial_rate`` or ``window`` is not positive.
        """
        if initial_rate <= 0:
            raise ValueError("initial_rate must be > 0")
        if window <= 0:
            raise ValueError("window must be > 0")
        self._initial_rate = initial_rate
        self._current_rate = initial_rate
        # A rate of 0 would block forever, and a floor above initial_rate would
        # make a 429 *raise* the rate, so keep the floor within [1, initial_rate].
        self._min_rate = min(max(1, min_rate), initial_rate)
        self._window = float(window)
        self._backoff_factor = backoff_factor
        self._recovery_factor = recovery_factor
        self._recovery_streak = max(1, recovery_streak)
        # Monotonic timestamps of admitted requests, oldest first.
        self._timestamps: deque[float] = deque()
        # Held for the whole of acquire() so waiters are admitted one at a time;
        # asyncio.Lock hands the lock over in the order tasks started waiting.
        self._lock = asyncio.Lock()
        self._success_streak: int = 0
        self._instance_id = AdaptiveRateLimiter._limiter_count
        AdaptiveRateLimiter._limiter_count += 1

    @property
    def current_rate(self) -> int:
        """Requests per window currently allowed."""
        return self._current_rate

    async def acquire(self) -> None:
        """Wait until a request fits under the current rate, then record it.

        The rate is re-read on every pass. A backoff or recovery that happens
        while a caller is waiting therefore applies to that caller as well.
        """
        async with self._lock:
            while True:
                now = time.monotonic()
                horizon = now - self._window
                stamps = self._timestamps
                while stamps and stamps[0] <= horizon:
                    stamps.popleft()
                excess = len(stamps) - self._current_rate
                if excess < 0:
                    stamps.append(now)
                    return
                # The oldest (excess + 1) entries must expire before we are
                # under the limit; stamps[excess] is the last of them.
                await asyncio.sleep(stamps[excess] + self._window - now)

    def _backed_off_rate(self) -> int:
        """Return the rate to use after a 429 (never below ``min_rate``)."""
        return max(self._min_rate, int(self._current_rate * self._backoff_factor))

    def _recovered_rate(self) -> int:
        """Return the rate after one recovery step (never above ``initial_rate``)."""
        grown = int(self._current_rate * self._recovery_factor)
        # int() truncation would leave small rates stuck (int(4 * 1.2) == 4),
        # so always grow by at least one request per window.
        return min(self._initial_rate, max(self._current_rate + 1, grown))

    def record_429(self) -> None:
        """Called when a 429 is received — reduce rate immediately."""
        self._current_rate = self._backed_off_rate()
        self._success_streak = 0
        logger.warning(
            "ADAPTIVE_RATE: instance={} backed off to {} req/min (429 received)",
            self._instance_id,
            self._current_rate,
        )

    def record_success(self) -> None:
        """Called on success — step the rate back toward ``initial_rate``.

        The rate recovers one step after every ``recovery_streak`` consecutive
        successes. At ``initial_rate`` the streak is simply kept at zero.
        """
        if self._current_rate >= self._initial_rate:
            self._success_streak = 0
            return
        self._success_streak += 1
        if self._success_streak < self._recovery_streak:
            return
        self._current_rate = self._recovered_rate()
        self._success_streak = 0
        logger.info(
            "ADAPTIVE_RATE: instance={} recovered to {} req/min",
            self._instance_id,
            self._current_rate,
        )
class ModelHealthTracker:
    """Track per-model health from recent failure timestamps.

    A model is unhealthy once it has ``max_failures`` or more failures within
    the last ``failure_ttl`` seconds. Refs starting with ``zen/`` and
    ``nvidia_nim/`` use their own (ttl, max_failures) pair; all other refs use
    the generic pair.
    """

    # String forward reference: the class name does not exist yet while the
    # class body is executing.
    _instance: ClassVar["ModelHealthTracker | None"] = None

    def __init__(
        self,
        failure_ttl: float = 30.0,
        max_failures: int = 3,
        *,
        failure_ttl_nim: float = 15.0,
        max_failures_nim: int = 2,
        failure_ttl_zen: float = 60.0,
        max_failures_zen: int = 5,
    ) -> None:
        self._failure_ttl = failure_ttl
        self._max_failures = max_failures
        self._failure_ttl_nim = failure_ttl_nim
        self._max_failures_nim = max_failures_nim
        self._failure_ttl_zen = failure_ttl_zen
        self._max_failures_zen = max_failures_zen
        # model_ref -> monotonic failure timestamps, pruned lazily.
        self._failures: dict[str, list[float]] = {}
        # Optional per-model overrides of ttl / max_failures. Nothing in this
        # class writes them — presumably populated elsewhere (TODO confirm).
        self._failure_ttls: dict[str, float] = {}
        self._max_failures_map: dict[str, int] = {}

    @classmethod
    def get_instance(cls) -> "ModelHealthTracker":
        """Return the shared tracker, creating it with defaults on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def _params_for(self, model_ref: str) -> tuple[float, int]:
        """Return (failure_ttl, max_failures) for a model.

        Per-model overrides win. Otherwise the value comes from the provider
        prefix. Each value falls back independently, so having only one
        override set no longer raises KeyError.
        """
        if model_ref.startswith("zen/"):
            ttl, max_f = self._failure_ttl_zen, self._max_failures_zen
        elif model_ref.startswith("nvidia_nim/"):
            ttl, max_f = self._failure_ttl_nim, self._max_failures_nim
        else:
            ttl, max_f = self._failure_ttl, self._max_failures
        return (
            self._failure_ttls.get(model_ref, ttl),
            self._max_failures_map.get(model_ref, max_f),
        )

    def _recent_failures(self, model_ref: str) -> list[float]:
        """Drop expired failures for a model and return the remaining ones.

        The stored list is replaced by the pruned one. A model with no
        remaining failures is removed from the map entirely.
        """
        stamps = self._failures.get(model_ref)
        if not stamps:
            self._failures.pop(model_ref, None)
            return []
        ttl, _ = self._params_for(model_ref)
        cutoff = time.monotonic() - ttl
        recent = [t for t in stamps if t > cutoff]
        if recent:
            self._failures[model_ref] = recent
        else:
            del self._failures[model_ref]
        return recent

    def record_failure(self, model_ref: str) -> None:
        """Record a failure timestamp for a model.

        Expired entries are pruned first. This keeps the per-model list bounded
        by one TTL window's worth of failures even if ``is_healthy`` is never
        called for the model.
        """
        recent = self._recent_failures(model_ref)
        recent.append(time.monotonic())
        self._failures[model_ref] = recent
        logger.debug("HEALTH: recorded failure for '{}'", model_ref)

    def is_healthy(self, model_ref: str) -> bool:
        """Check if model has had fewer than max_failures in the TTL window."""
        recent = self._recent_failures(model_ref)
        if not recent:
            return True
        ttl, max_f = self._params_for(model_ref)
        healthy = len(recent) < max_f
        if not healthy:
            logger.debug(
                "HEALTH: model '{}' is unhealthy ({} failures in {}s)",
                model_ref,
                len(recent),
                ttl,
            )
        return healthy

    def get_failure_count(self, model_ref: str) -> int:
        """Get number of recent (non-expired) failures for a model."""
        return len(self._recent_failures(model_ref))

    def clear_failures(self, model_ref: str) -> None:
        """Clear failure history for a model (on success); unknown refs are ignored."""
        self._failures.pop(model_ref, None)
class GlobalRateLimiter:
"""
Global singleton rate limiter that blocks all requests
when a rate limit error is encountered (reactive) and
throttles requests (proactive) using a strict rolling window.
Optionally enforces a max_concurrency cap: at most N provider streams
may be open simultaneously, independent of the sliding window.
Proactive limits - throttles requests to stay within API limits.
Reactive limits - pauses all requests when a 429 is hit.
Concurrency limit - caps simultaneously open streams.
"""
_instance: ClassVar[GlobalRateLimiter | None] = None
_scoped_instances: ClassVar[dict[str, GlobalRateLimiter]] = {}
def __init__(
self,
rate_limit: int = 40,
rate_window: float = 60.0,
max_concurrency: int = 5,
adaptive_rate: int | None = None,
adaptive_min_rate: int = 10,
):
# Prevent re-initialization on singleton reuse
if hasattr(self, "_initialized"):
return
if rate_limit <= 0:
raise ValueError("rate_limit must be > 0")
if rate_window <= 0:
raise ValueError("rate_window must be > 0")
if max_concurrency <= 0:
raise ValueError("max_concurrency must be > 0")
self._rate_limit = rate_limit
self._rate_window = float(rate_window)
self._max_concurrency = max_concurrency
self._adaptive_rate = adaptive_rate
self._adaptive_min_rate = adaptive_min_rate
if adaptive_rate is not None:
self._proactive_limiter = AdaptiveRateLimiter(
initial_rate=adaptive_rate,
min_rate=adaptive_min_rate,
window=float(rate_window),
)
else:
self._proactive_limiter = StrictSlidingWindowLimiter(
rate_limit, float(rate_window)
)
self._blocked_until: float = 0
self._concurrency_sem = asyncio.Semaphore(max_concurrency)
self._initialized = True
limiter_type = (
f"Adaptive({adaptive_rate}{adaptive_min_rate})"
if adaptive_rate is not None
else f"Strict({rate_limit})"
)
logger.info(
f"GlobalRateLimiter initialized {limiter_type} / {rate_window}s, max_concurrency={max_concurrency}"
)
@classmethod
def get_instance(
cls,
rate_limit: int | None = None,
rate_window: float | None = None,
max_concurrency: int = 5,
) -> GlobalRateLimiter:
"""Get or create the singleton instance.
Args:
rate_limit: Requests per window (only used on first creation)
rate_window: Window in seconds (only used on first creation)
max_concurrency: Max simultaneous open streams (only used on first creation)
"""
if cls._instance is None:
cls._instance = cls(
rate_limit=rate_limit or 40,
rate_window=rate_window or 60.0,
max_concurrency=max_concurrency,
)
return cls._instance
@classmethod
def get_scoped_instance(
cls,
scope: str,
*,
rate_limit: int | None = None,
rate_window: float | None = None,
max_concurrency: int = 5,
adaptive_rate: int | None = None,
adaptive_min_rate: int = 10,
) -> GlobalRateLimiter:
"""Get or create a provider-scoped limiter instance.
Zen gets unlimited adaptive rate (9999) since it has no rate limits.
NIM gets adaptive rate from nim_rate_limit setting.
"""
if not scope:
raise ValueError("scope must be non-empty")
desired_rate_limit = 9999 if scope == "zen" else rate_limit or 40
desired_rate_window = float(rate_window or 60.0)
existing = cls._scoped_instances.get(scope)
if existing and existing.matches_config(
desired_rate_limit, desired_rate_window, max_concurrency
):
return existing
if existing:
logger.info(
"Rebuilding provider rate limiter for updated scope '{}'", scope
)
# Adaptive rate only for NIM (not Zen which is unlimited)
use_adaptive = adaptive_rate if scope == "nvidia_nim" else None
cls._scoped_instances[scope] = cls(
rate_limit=desired_rate_limit,
rate_window=desired_rate_window,
max_concurrency=max_concurrency,
adaptive_rate=use_adaptive,
adaptive_min_rate=adaptive_min_rate,
)
return cls._scoped_instances[scope]
@classmethod
def reset_instance(cls) -> None:
"""Reset singleton (for testing)."""
cls._instance = None
cls._scoped_instances = {}
async def wait_if_blocked(self) -> bool:
"""
Wait if currently rate limited or throttle to meet quota.
Returns:
True if was reactively blocked and waited, False otherwise.
"""
# 1. Reactive check: Wait if someone hit a 429
waited_reactively = False
now = time.monotonic()
if now < self._blocked_until:
wait_time = self._blocked_until - now
logger.warning(
f"Global provider rate limit active (reactive), waiting {wait_time:.1f}s..."
)
await asyncio.sleep(wait_time)
waited_reactively = True
# 2. Proactive check: strict rolling window (no bursts beyond N in last W seconds)
await self._acquire_proactive_slot()
return waited_reactively
async def _acquire_proactive_slot(self) -> None:
"""
Acquire a proactive slot enforcing a strict rolling window.
Guarantees: at most `self._rate_limit` acquisitions in any interval of length
`self._rate_window` (seconds).
"""
await self._proactive_limiter.acquire()
def set_blocked(self, seconds: float = 60) -> None:
"""
Set global block for specified seconds (reactive).
Args:
seconds: How long to block (default 60s)
"""
self._blocked_until = time.monotonic() + seconds
logger.warning(f"Global provider rate limit set for {seconds:.1f}s (reactive)")
def is_blocked(self) -> bool:
"""Check if currently reactively blocked."""
return time.monotonic() < self._blocked_until
def matches_config(
self, rate_limit: int, rate_window: float, max_concurrency: int
) -> bool:
"""Return whether this limiter matches the requested runtime config."""
return (
self._rate_limit == rate_limit
and self._rate_window == float(rate_window)
and self._max_concurrency == max_concurrency
)
def remaining_wait(self) -> float:
"""Get remaining reactive wait time in seconds."""
return max(0.0, self._blocked_until - time.monotonic())
def record_failure(self, model_ref: str | None = None) -> None:
"""Record a failure for rate limit tracking.
Args:
model_ref: Optional model identifier for health tracking.
"""
# Record in the shared health tracker if model provided
if model_ref:
health = ModelHealthTracker.get_instance()
health.record_failure(model_ref)
def is_healthy(self, model_ref: str | None = None) -> bool:
"""Check if provider/model is healthy based on failure history.
Args:
model_ref: Optional model identifier for health tracking.
Returns:
True if no recent failures or model_ref is None.
"""
if model_ref is None:
return True
health = ModelHealthTracker.get_instance()
return health.is_healthy(model_ref)
@asynccontextmanager
async def concurrency_slot(self) -> AsyncIterator[None]:
"""Async context manager that holds one concurrency slot for a stream.
Blocks until a slot is available (controlled by max_concurrency).
"""
await self._concurrency_sem.acquire()
try:
yield
finally:
self._concurrency_sem.release()
async def execute_with_retry(
self,
fn: Callable[..., Any],
*args: Any,
max_retries: int = 3,
base_delay: float = 0.3,
max_delay: float = 20.0,
jitter: float = 0.1,
**kwargs: Any,
) -> Any:
"""Execute an async callable with rate limiting and retry on 429.
Waits for the proactive limiter before each attempt. On 429, applies
adaptive backoff and notifies the adaptive rate limiter. Snappier recovery
than fixed delays.
Args:
fn: Async callable to execute.
max_retries: Maximum number of retry attempts after the first failure.
base_delay: Base delay in seconds for exponential backoff.
max_delay: Maximum delay cap in seconds.
jitter: Maximum random jitter in seconds added to each delay.
Returns:
The result of the callable.
Raises:
The last exception if all retries are exhausted.
"""
last_exc: Exception | None = None
for attempt in range(1 + max_retries):
await self.wait_if_blocked()
try:
result = await fn(*args, **kwargs)
# Notify adaptive limiter of success (triggers gradual recovery)
self._record_success_for_adaptive()
return result
except openai.RateLimitError as e:
last_exc = e
self._record_429_for_adaptive()
if attempt >= max_retries:
logger.warning(
f"Rate limit retry exhausted after {max_retries} retries"
)
break
delay = min(base_delay * (2**attempt), max_delay)
delay += random.uniform(0, jitter)
logger.warning(
f"Rate limited (429), attempt {attempt + 1}/{max_retries + 1}. "
f"Retrying in {delay:.1f}s..."
)
self.set_blocked(delay)
await asyncio.sleep(delay)
except httpx.HTTPStatusError as e:
if e.response.status_code != 429:
raise
last_exc = e
self._record_429_for_adaptive()
if attempt >= max_retries:
logger.warning(
f"HTTP 429 retry exhausted after {max_retries} retries"
)
break
delay = min(base_delay * (2**attempt), max_delay)
delay += random.uniform(0, jitter)
logger.warning(
f"HTTP 429 from upstream, attempt {attempt + 1}/{max_retries + 1}. "
f"Retrying in {delay:.1f}s..."
)
self.set_blocked(delay)
await asyncio.sleep(delay)
assert last_exc is not None
raise last_exc
def _record_429_for_adaptive(self) -> None:
"""Notify adaptive limiter of a 429 — triggers rate backoff."""
if isinstance(self._proactive_limiter, AdaptiveRateLimiter):
self._proactive_limiter.record_429()
def _record_success_for_adaptive(self) -> None:
"""Notify adaptive limiter of success — triggers gradual rate recovery."""
if isinstance(self._proactive_limiter, AdaptiveRateLimiter):
self._proactive_limiter.record_success()