# secureagentrag-api / utils/rate_limiter.py
# Author: LeomordKaly
# Last commit: f4ef3b8 "deploy: phase 3 BYOK backend (Dockerfile.hf, FastAPI on 7860)" (verified)
"""Token-bucket rate limiter for API request throttling.
Provides per-user and per-endpoint rate limiting to prevent abuse and
ensure fair resource allocation. Uses an in-memory token bucket algorithm
with optional Redis backend for distributed deployments.
"""
from __future__ import annotations

import time
import uuid
from dataclasses import dataclass, field
from typing import Any

from utils.logging import get_logger
# Module-level logger. Calls in this module pass structured key/value
# kwargs (e.g. ``key=...``), so get_logger presumably returns a
# structlog-style logger — TODO confirm against utils.logging.
logger = get_logger(__name__)
@dataclass
class RateLimitConfig:
    """Tunable parameters for a single rate-limit bucket.

    Attributes:
        requests_per_minute: Sustained request budget per minute; the
            in-memory token bucket refills at ``requests_per_minute / 60``
            tokens per second.
        burst_size: Bucket capacity, i.e. how many requests may be issued
            back-to-back before throttling starts.
        cooldown_seconds: How long a bucket stays blocked once it has been
            exhausted; also used to derive ``retry_after`` hints.
    """

    requests_per_minute: int = field(default=60)
    burst_size: int = field(default=10)
    cooldown_seconds: float = field(default=1.0)
@dataclass
class TokenBucket:
    """In-memory token bucket for a single rate-limit key.

    Tokens refill continuously at ``config.requests_per_minute / 60`` per
    second, capped at ``config.burst_size``. When a request cannot be paid
    for, the bucket is blocked for ``config.cooldown_seconds``.

    Attributes:
        tokens: Currently available tokens (fractional between refills).
        last_update: ``time.time()`` timestamp of the last refill.
        config: Rate limit configuration for this bucket.
        blocked_until: ``time.time()`` timestamp until which requests are
            rejected with reason ``"cooldown_active"``.
    """

    tokens: float = field(default=0.0)
    last_update: float = field(default_factory=time.time)
    config: RateLimitConfig = field(default_factory=RateLimitConfig)
    blocked_until: float = field(default=0.0)

    def _refill_rate(self) -> float:
        """Return the refill rate in tokens per second."""
        return self.config.requests_per_minute / 60.0

    def _refill(self) -> None:
        """Refill tokens based on elapsed time since the last update."""
        now = time.time()
        # time.time() is a wall clock and can jump backwards (NTP sync, manual
        # adjustment). Clamp so a backwards jump never *removes* tokens.
        elapsed = max(0.0, now - self.last_update)
        self.tokens = min(
            self.config.burst_size,
            self.tokens + elapsed * self._refill_rate(),
        )
        self.last_update = now

    def _seconds_until_affordable(self, tokens: float) -> float:
        """Return how long until ``tokens`` tokens will have refilled.

        Returns 0.0 when the request is already affordable, or when the
        bucket never refills (``requests_per_minute <= 0``), in which case
        the caller falls back to the cooldown alone. Requests larger than
        ``burst_size`` can never be satisfied; the value is then only a hint.
        """
        deficit = tokens - self.tokens
        rate = self._refill_rate()
        if deficit <= 0 or rate <= 0:
            return 0.0
        return deficit / rate

    def consume(self, tokens: float = 1.0) -> tuple[bool, dict[str, Any]]:
        """Attempt to consume tokens from the bucket.

        Args:
            tokens: Number of tokens to consume (default 1 per request).

        Returns:
            Tuple of ``(allowed, metadata)``; ``metadata`` carries
            ``allowed``, ``remaining`` (whole tokens left), ``retry_after``
            (whole seconds, 0 when allowed) and ``reason``
            (``None``, ``"cooldown_active"`` or ``"rate_limit_exceeded"``).
        """
        now = time.time()

        # Still inside a previous cooldown: reject without touching tokens.
        if now < self.blocked_until:
            retry_after = int(self.blocked_until - now) + 1
            return False, {
                "allowed": False,
                "remaining": 0,
                "retry_after": retry_after,
                "reason": "cooldown_active",
            }

        self._refill()
        if self.tokens >= tokens:
            self.tokens -= tokens
            return True, {
                "allowed": True,
                "remaining": int(self.tokens),
                "retry_after": 0,
                "reason": None,
            }

        # Rate limit exceeded — enter cooldown.
        self.blocked_until = now + self.config.cooldown_seconds
        # Advertise the later of "cooldown over" and "enough tokens refilled".
        # Advertising only the cooldown invites a retry that is rejected again
        # whenever refill is slower than the cooldown.
        wait = max(self.config.cooldown_seconds, self._seconds_until_affordable(tokens))
        retry_after = int(wait) + 1
        return False, {
            "allowed": False,
            "remaining": 0,
            "retry_after": retry_after,
            "reason": "rate_limit_exceeded",
        }
class RateLimiter:
    """In-process rate limiter holding one :class:`TokenBucket` per key.

    Keys are free-form strings; the helpers in this module build them as
    ``"<user_id>:<action>"``. Buckets are kept in a plain dict owned by the
    instance, so state is not shared between processes — for that, use a
    Redis-backed implementation.

    Args:
        default_config: Configuration applied to buckets created without an
            explicit config.
    """

    def __init__(self, default_config: RateLimitConfig | None = None) -> None:
        """Create a limiter with no buckets.

        Args:
            default_config: Fallback configuration for new buckets; a fresh
                ``RateLimitConfig()`` is used when omitted.
        """
        self._default_config = default_config or RateLimitConfig()
        self._buckets: dict[str, TokenBucket] = {}

    def _get_bucket(self, key: str, config: RateLimitConfig | None = None) -> TokenBucket:
        """Return the bucket for ``key``, creating a full one on first use.

        ``config`` only matters when the bucket is created; an existing
        bucket keeps the configuration it was created with.

        Args:
            key: Unique identifier for the bucket (e.g. user_id + endpoint).
            config: Optional custom configuration for a new bucket.

        Returns:
            The token bucket for the key.
        """
        existing = self._buckets.get(key)
        if existing is not None:
            return existing
        effective = config or self._default_config
        # New buckets start full so the first burst is allowed immediately.
        created = TokenBucket(tokens=effective.burst_size, config=effective)
        self._buckets[key] = created
        return created

    def check_rate_limit(
        self,
        key: str,
        tokens: float = 1.0,
        config: RateLimitConfig | None = None,
    ) -> tuple[bool, dict[str, Any]]:
        """Consume ``tokens`` from the bucket for ``key`` and log the outcome.

        Args:
            key: Rate limit bucket key (e.g. ``"user_123:query"``).
            tokens: Tokens to consume.
            config: Optional custom config, used only if the bucket is new.

        Returns:
            ``(allowed, metadata)`` as produced by :meth:`TokenBucket.consume`.
        """
        allowed, metadata = self._get_bucket(key, config).consume(tokens)
        if allowed:
            logger.debug(
                "rate_limit_allowed",
                key=key,
                remaining=metadata["remaining"],
            )
        else:
            logger.warning(
                "rate_limit_exceeded",
                key=key,
                retry_after=metadata["retry_after"],
                reason=metadata["reason"],
            )
        return allowed, metadata

    def is_allowed(self, key: str, tokens: float = 1.0) -> bool:
        """Boolean shortcut for :meth:`check_rate_limit` with the default config.

        Args:
            key: Rate limit bucket key.
            tokens: Tokens to consume.

        Returns:
            True if the request fits within the limit.
        """
        return self.check_rate_limit(key, tokens)[0]

    def get_status(self, key: str) -> dict[str, Any]:
        """Report the current state of the bucket for ``key``.

        Unknown keys report the default configuration as if a full bucket
        existed. Known buckets are refilled first, so ``remaining`` is
        up to date.

        Args:
            key: Rate limit bucket key.

        Returns:
            Dict with ``remaining`` (whole tokens), ``limit`` (requests per
            minute) and ``reset`` (whole seconds of cooldown left, or 0).
        """
        bucket = self._buckets.get(key)
        if bucket is None:
            cfg = self._default_config
            return {
                "remaining": cfg.burst_size,
                "limit": cfg.requests_per_minute,
                "reset": 0,
            }
        bucket._refill()
        cooldown_left = max(0, bucket.blocked_until - time.time())
        return {
            "remaining": int(bucket.tokens),
            "limit": bucket.config.requests_per_minute,
            "reset": int(cooldown_left),
        }

    def reset(self, key: str) -> None:
        """Forget the bucket for ``key``; logs only if a bucket existed.

        Args:
            key: Bucket key to reset.
        """
        if self._buckets.pop(key, None) is not None:
            logger.info("rate_limit_reset", key=key)
class OwnerKeyHourThrottle:
"""Per-IP hourly throttle for the BYOK owner-key fallback.
Distinct from the request-level :class:`RateLimiter` because the BYOK
semantics are different:
- Visitors who bring their own LLM key (``ByokCreds.has_user_key()``)
bypass this throttle entirely — they are paying for their own tokens.
- Visitors who do NOT bring a key fall back to the platform owner's
Groq key. This throttle exists to stop a single recruiter or curious
visitor from burning the free-tier 30 RPM / 14,400 RPD budget.
Bucket window is rolling one hour from the first allowed request in
the window. Sliding-window precision is not needed — three requests an
hour is already conservative. We keep timestamps in a tiny list per IP
and prune entries older than 3600 seconds on each check.
"""
__slots__ = ("_buckets", "_quota_per_hour")
def __init__(self, quota_per_hour: int) -> None:
if quota_per_hour < 0:
raise ValueError("quota_per_hour must be non-negative")
self._quota_per_hour = quota_per_hour
self._buckets: dict[str, list[float]] = {}
def allow(self, ip: str, *, now: float | None = None) -> tuple[bool, dict[str, Any]]:
"""Return whether ``ip`` may consume one owner-key request.
Args:
ip: Client IP address (use ``"anon"`` when unavailable so the
fallback path still throttles instead of leaking quota).
now: Optional monotonic clock override for tests.
Returns:
``(allowed, meta)`` where ``meta`` carries ``remaining`` and
``retry_after`` seconds, ready for an HTTP 429 response.
"""
t = now if now is not None else time.monotonic()
# Prune entries older than 1h, then count.
bucket = [ts for ts in self._buckets.get(ip, []) if t - ts < 3600.0]
if len(bucket) >= self._quota_per_hour:
# ``retry_after`` defaults to a full window when quota_per_hour=0
# (kill switch) — there is no "oldest entry" to expire.
retry_after = max(1, int(3600.0 - (t - bucket[0])) + 1) if bucket else 3600
self._buckets[ip] = bucket # write pruned list back
return False, {
"allowed": False,
"remaining": 0,
"retry_after": retry_after,
"reason": "owner_key_hourly_quota_exhausted",
}
bucket.append(t)
self._buckets[ip] = bucket
return True, {
"allowed": True,
"remaining": self._quota_per_hour - len(bucket),
"retry_after": 0,
"reason": None,
}
def reset(self, ip: str) -> None:
"""Drop all timestamps for ``ip`` (test/cleanup helper)."""
self._buckets.pop(ip, None)
def reset_all(self) -> None:
"""Drop every bucket — used between test cases to avoid leakage."""
self._buckets.clear()
# Module-level singleton — lazy-initialised from settings on first use so
# unit tests that monkey-patch SAR_BYOK_OWNER_QUOTA see the right value.
# Stays ``None`` until get_owner_key_throttle() builds it;
# reset_owner_key_throttle() sets it back to ``None``.
_owner_key_throttle: OwnerKeyHourThrottle | None = None
def get_owner_key_throttle() -> OwnerKeyHourThrottle:
    """Return the shared owner-key throttle, building it on first use.

    The quota comes from ``settings.byok_owner_key_quota_per_hour`` and is
    read only when the singleton is (re)built. After monkey-patching that
    setting, tests must call :func:`reset_owner_key_throttle` so the next
    call here picks up the new value.
    """
    global _owner_key_throttle
    if _owner_key_throttle is not None:
        return _owner_key_throttle

    # Imported here rather than at module level to avoid an import cycle.
    from config.settings import settings

    quota = settings.byok_owner_key_quota_per_hour
    _owner_key_throttle = OwnerKeyHourThrottle(quota_per_hour=quota)
    return _owner_key_throttle
def reset_owner_key_throttle() -> None:
    """Discard the cached owner-key throttle.

    The next :func:`get_owner_key_throttle` call rebuilds it from the current
    settings. Intended as a test hook rather than for production use.
    """
    global _owner_key_throttle
    _owner_key_throttle = None
class RedisRateLimiter:
    """Distributed rate limiter backed by Redis sorted sets.

    Every key ``k`` maps to the sorted set ``ratelimit:k``. Each admitted
    request is stored as one uniquely named member scored with its
    ``time.time()`` timestamp. A request is admitted while fewer than
    ``min(burst_size, requests_per_minute)`` members fall inside the
    trailing 60-second window, so every instance using the same Redis shares
    the limit.

    Note:
        ``burst_size`` is checked against the same 60 s window as
        ``requests_per_minute``. Unlike the in-memory :class:`TokenBucket`,
        it is not a short-term burst allowance: whenever
        ``burst_size <= requests_per_minute`` the effective limit is
        ``burst_size`` requests per minute.

        Reading the window (prune + count) and recording the request are two
        separate round trips, so concurrent instances can slightly overshoot
        the limit. Strict enforcement would need a server-side script.

    Args:
        redis_url: Redis connection URL.
        default_config: Default rate limit configuration.
    """

    # Length of the sliding window, in seconds.
    _WINDOW_SECONDS = 60.0
    # TTL refreshed on every write so idle keys are evicted by Redis.
    _KEY_TTL_SECONDS = 120

    def __init__(
        self,
        redis_url: str | None = None,
        default_config: RateLimitConfig | None = None,
    ) -> None:
        """Initialize the Redis rate limiter.

        Args:
            redis_url: Redis connection URL. Falls back to ``settings.redis_url``.
            default_config: Default configuration for new keys.
        """
        import redis

        from config.settings import settings

        self._redis = redis.from_url(redis_url or settings.redis_url)
        self._default_config = default_config or RateLimitConfig()

    @staticmethod
    def _redis_key(key: str) -> str:
        """Return the Redis key holding the window for ``key``."""
        return f"ratelimit:{key}"

    def _window_snapshot(self, redis_key: str, now: float) -> tuple[int, float | None]:
        """Drop expired entries, then return ``(count, oldest_timestamp)``.

        The three commands go out in one pipeline (MULTI/EXEC by default in
        redis-py), so the count and the oldest entry describe the same state.
        ``oldest_timestamp`` is ``None`` when the window is empty.
        """
        pipe = self._redis.pipeline()
        pipe.zremrangebyscore(redis_key, 0, now - self._WINDOW_SECONDS)
        pipe.zcard(redis_key)
        pipe.zrange(redis_key, 0, 0, withscores=True)
        _, count, oldest = pipe.execute()
        oldest_ts = float(oldest[0][1]) if oldest else None
        return int(count), oldest_ts

    def _seconds_until_slot_frees(self, oldest_ts: float | None, now: float) -> int:
        """Whole seconds (rounded up, at least 1) until the oldest entry expires."""
        if oldest_ts is None:
            # Limit reached with an empty window (limit <= 0): nothing will
            # ever expire, so advertise a full window.
            return int(self._WINDOW_SECONDS)
        return max(1, int(oldest_ts + self._WINDOW_SECONDS - now) + 1)

    def check_rate_limit(
        self,
        key: str,
        tokens: float = 1.0,
        config: RateLimitConfig | None = None,
    ) -> tuple[bool, dict[str, Any]]:
        """Check if a request is within the rate limit using Redis.

        Args:
            key: Rate limit bucket key.
            tokens: Accepted for interface parity with :class:`RateLimiter`;
                ignored here — each admitted call records exactly one request.
            config: Optional custom config (defaults to ``default_config``).

        Returns:
            Tuple of ``(allowed, metadata)`` with the same metadata keys as
            the in-memory limiter.
        """
        cfg = config or self._default_config
        now = time.time()
        redis_key = self._redis_key(key)
        current_count, oldest_ts = self._window_snapshot(redis_key, now)

        # Burst cap, checked against the same 60 s window (see class Note).
        if current_count >= cfg.burst_size:
            # Never advertise less than the cooldown, nor less than the time
            # until a slot frees; otherwise a client retrying at
            # ``retry_after`` is simply rejected again.
            retry_after = max(
                int(cfg.cooldown_seconds) + 1,
                self._seconds_until_slot_frees(oldest_ts, now),
            )
            return False, {
                "allowed": False,
                "remaining": 0,
                "retry_after": retry_after,
                "reason": "rate_limit_exceeded",
            }

        rpm_limit = cfg.requests_per_minute
        if current_count >= rpm_limit:
            # The window slides, so the wait depends on the oldest recorded
            # request, not on the wall-clock minute boundary.
            retry_after = self._seconds_until_slot_frees(oldest_ts, now)
            return False, {
                "allowed": False,
                "remaining": 0,
                "retry_after": retry_after,
                "reason": "rate_limit_exceeded",
            }

        # Use a unique member per request. Naming members by timestamp alone
        # makes requests with an identical time.time() value collapse into a
        # single sorted-set member, so they would be counted once.
        member = f"{now}:{uuid.uuid4().hex}"
        pipe = self._redis.pipeline()
        pipe.zadd(redis_key, {member: now})
        pipe.expire(redis_key, self._KEY_TTL_SECONDS)
        pipe.execute()

        remaining = min(cfg.burst_size, rpm_limit) - current_count - 1
        return True, {
            "allowed": True,
            "remaining": max(0, remaining),
            "retry_after": 0,
            "reason": None,
        }

    def is_allowed(self, key: str, tokens: float = 1.0) -> bool:
        """Simple check — returns True if request is allowed.

        Args:
            key: Rate limit bucket key.
            tokens: Forwarded to :meth:`check_rate_limit` (ignored there).

        Returns:
            True if within rate limit, False otherwise.
        """
        allowed, _ = self.check_rate_limit(key, tokens)
        return allowed

    def get_status(self, key: str) -> dict[str, Any]:
        """Get current rate limit status for a key (default config).

        Args:
            key: Rate limit bucket key.

        Returns:
            Dict with ``remaining`` (requests still admissible in the
            window), ``limit`` (requests per minute) and ``reset`` (seconds
            until a slot frees when the limit is reached, otherwise 0).
        """
        cfg = self._default_config
        now = time.time()
        count, oldest_ts = self._window_snapshot(self._redis_key(key), now)
        # Same effective limit that check_rate_limit admits against.
        limit = min(cfg.burst_size, cfg.requests_per_minute)
        reset = self._seconds_until_slot_frees(oldest_ts, now) if count >= limit else 0
        return {
            "remaining": max(0, limit - count),
            "limit": cfg.requests_per_minute,
            "reset": reset,
        }

    def reset(self, key: str) -> None:
        """Reset a specific rate limit bucket.

        Args:
            key: Bucket key to reset.
        """
        self._redis.delete(self._redis_key(key))
        logger.info("redis_rate_limit_reset", key=key)
def _get_rate_limiter() -> RateLimiter | RedisRateLimiter:
    """Build the rate limiter selected by ``settings.use_redis_rate_limiter``.

    When Redis is requested, the client is created and pinged up front.
    ``redis.from_url`` connects lazily, so without the ping an unreachable
    Redis would only fail on the first rate-limit check. By then this
    function would already have committed to the Redis backend, and the
    in-memory fallback below would never be used.

    Returns:
        RedisRateLimiter when enabled and reachable, otherwise an in-memory
        RateLimiter. Both use the ``"default"`` profile as their default
        config.
    """
    from config.settings import settings

    default_config = RATE_LIMIT_PROFILES["default"]
    if settings.use_redis_rate_limiter:
        try:
            limiter = RedisRateLimiter(default_config=default_config)
            # Force a round trip now so connection problems trigger the fallback.
            limiter._redis.ping()
        except Exception as exc:  # startup boundary: any failure degrades to memory
            logger.warning("redis_rate_limiter_failed", error=str(exc), fallback="memory")
        else:
            return limiter
    return RateLimiter(default_config=default_config)
# Named rate-limit presets. "default" seeds the module-level limiter's
# default config; "query" and "upload" are applied by
# check_query_rate_limit / check_upload_rate_limit. Every field is spelled
# out (1.0 is the RateLimitConfig cooldown default).
RATE_LIMIT_PROFILES: dict[str, RateLimitConfig] = {
    "default": RateLimitConfig(
        requests_per_minute=60, burst_size=10, cooldown_seconds=1.0
    ),
    "strict": RateLimitConfig(
        requests_per_minute=10, burst_size=3, cooldown_seconds=5.0
    ),
    "generous": RateLimitConfig(
        requests_per_minute=300, burst_size=50, cooldown_seconds=1.0
    ),
    "upload": RateLimitConfig(
        requests_per_minute=5, burst_size=2, cooldown_seconds=10.0
    ),
    "query": RateLimitConfig(
        requests_per_minute=30, burst_size=5, cooldown_seconds=2.0
    ),
}
# Module-level singleton (lazy initialization): populated by _get_limiter()
# on its first call and reused by every later call.
_rate_limiter_instance: RateLimiter | RedisRateLimiter | None = None
def _get_limiter() -> RateLimiter | RedisRateLimiter:
    """Return the process-wide limiter, building it on the first call.

    Returns:
        The limiter chosen by :func:`_get_rate_limiter`, cached for reuse.
    """
    global _rate_limiter_instance
    if _rate_limiter_instance is not None:
        return _rate_limiter_instance
    _rate_limiter_instance = _get_rate_limiter()
    return _rate_limiter_instance
def check_query_rate_limit(user_id: str) -> tuple[bool, dict[str, Any]]:
    """Apply the ``"query"`` profile to a user's query.

    Args:
        user_id: Identifier of the user issuing the query.

    Returns:
        ``(allowed, metadata)`` from the shared limiter for key
        ``"<user_id>:query"``.
    """
    limiter = _get_limiter()
    profile = RATE_LIMIT_PROFILES["query"]
    return limiter.check_rate_limit(f"{user_id}:query", config=profile)
def check_upload_rate_limit(user_id: str) -> tuple[bool, dict[str, Any]]:
    """Apply the ``"upload"`` profile to a user's document upload.

    Args:
        user_id: Identifier of the uploading user.

    Returns:
        ``(allowed, metadata)`` from the shared limiter for key
        ``"<user_id>:upload"``.
    """
    limiter = _get_limiter()
    profile = RATE_LIMIT_PROFILES["upload"]
    return limiter.check_rate_limit(f"{user_id}:upload", config=profile)