| | import time |
| | from typing import Optional, Dict |
| | from open_webui.env import REDIS_KEY_PREFIX |
| |
|
| |
|
class RateLimiter:
    """
    General-purpose rate limiter using Redis with a rolling window strategy.
    Falls back to in-memory storage if Redis is not available.

    Events are counted in fixed time buckets of ``bucket_size`` seconds. A
    check sums the current (partial) bucket plus the ``window // bucket_size``
    buckets before it, so the effective window is
    ``(window // bucket_size) * bucket_size`` seconds plus the elapsed part of
    the current bucket.

    Keys are case-insensitive in both backends: they are lower-cased before
    use, so e.g. ``"User@x.com"`` and ``"user@x.com"`` share one counter
    whether Redis or the in-memory fallback is in use.
    """

    # In-memory fallback store: normalized key -> {bucket index: event count}.
    # This is a class attribute, so every RateLimiter instance in the process
    # shares it (just as all instances share one Redis keyspace). No locking
    # is performed around it.
    _memory_store: Dict[str, Dict[int, int]] = {}

    def __init__(
        self,
        redis_client,
        limit: int,
        window: int,
        bucket_size: int = 60,
        enabled: bool = True,
    ):
        """
        :param redis_client: Redis client instance or None
        :param limit: Max allowed events in the window
        :param window: Time window in seconds
        :param bucket_size: Bucket resolution in seconds
        :param enabled: Turn on/off rate limiting globally
        """
        self.r = redis_client
        self.limit = limit
        self.window = window
        self.bucket_size = bucket_size
        # Number of full buckets looked back over, in addition to the current one.
        self.num_buckets = window // bucket_size
        self.enabled = enabled

    @staticmethod
    def _normalize_key(key: str) -> str:
        """Return the canonical, case-insensitive form of a rate-limit key."""
        return key.lower()

    def _bucket_key(self, key: str, bucket_index: int) -> str:
        """Build the Redis key holding the counter for ``key`` in one bucket."""
        return f"{REDIS_KEY_PREFIX}:ratelimit:{self._normalize_key(key)}:{bucket_index}"

    def _current_bucket(self) -> int:
        """Index of the bucket the current wall-clock time falls into."""
        return int(time.time()) // self.bucket_size

    def _redis_available(self) -> bool:
        """True when a Redis client was supplied."""
        return self.r is not None

    def is_limited(self, key: str) -> bool:
        """
        Record one event for ``key`` and report whether it exceeds the limit.

        Gracefully handles missing or failing Redis: any exception raised by
        the Redis path falls back to the in-process store.

        :return: True if the event count in the window (including this one)
                 is greater than ``limit``; always False when disabled.
        """
        if not self.enabled:
            return False

        if self._redis_available():
            try:
                return self._is_limited_redis(key)
            except Exception:
                # Redis is best-effort; fall through to the in-memory store.
                pass
        return self._is_limited_memory(key)

    def get_count(self, key: str) -> int:
        """
        Return how many events are recorded for ``key`` in the current window
        without recording a new one. Returns 0 when disabled.
        """
        if not self.enabled:
            return 0

        if self._redis_available():
            try:
                return self._get_count_redis(key)
            except Exception:
                # Redis is best-effort; fall through to the in-memory store.
                pass
        return self._get_count_memory(key)

    def remaining(self, key: str) -> int:
        """Events still allowed for ``key`` in the current window (never negative)."""
        used = self.get_count(key)
        return max(0, self.limit - used)

    # ------------------------------------------------------------------ Redis

    def _window_bucket_keys(self, key: str, now_bucket: int) -> list:
        """Redis keys of every bucket in the window ending at ``now_bucket``, newest first."""
        return [
            self._bucket_key(key, now_bucket - i) for i in range(self.num_buckets + 1)
        ]

    def _redis_total(self, key: str, now_bucket: int) -> int:
        """Sum the Redis counters of all buckets in the window ending at ``now_bucket``."""
        counts = self.r.mget(self._window_bucket_keys(key, now_bucket))
        # Missing buckets come back as None and are skipped.
        return sum(int(c) for c in counts if c)

    def _is_limited_redis(self, key: str) -> bool:
        """Redis-backed implementation of :meth:`is_limited`."""
        now_bucket = self._current_bucket()
        bucket_key = self._bucket_key(key, now_bucket)

        attempts = self.r.incr(bucket_key)
        if attempts == 1:
            # First event in this bucket: give the key a TTL so it outlives
            # the window by one bucket and then expires on its own.
            self.r.expire(bucket_key, self.window + self.bucket_size)

        return self._redis_total(key, now_bucket) > self.limit

    def _get_count_redis(self, key: str) -> int:
        """Redis-backed implementation of :meth:`get_count`."""
        return self._redis_total(key, self._current_bucket())

    # ----------------------------------------------------------------- Memory

    def _prune_memory(self, store: Dict[int, int], now_bucket: int) -> None:
        """Delete, in place, buckets of ``store`` older than the current window."""
        min_bucket = now_bucket - self.num_buckets
        for bucket in [b for b in store if b < min_bucket]:
            del store[bucket]

    def _is_limited_memory(self, key: str) -> bool:
        """In-process implementation of :meth:`is_limited`."""
        now_bucket = self._current_bucket()
        store = self._memory_store.setdefault(self._normalize_key(key), {})

        store[now_bucket] = store.get(now_bucket, 0) + 1
        self._prune_memory(store, now_bucket)

        return sum(store.values()) > self.limit

    def _get_count_memory(self, key: str) -> int:
        """In-process implementation of :meth:`get_count`."""
        mem_key = self._normalize_key(key)
        store = self._memory_store.get(mem_key)
        if store is None:
            return 0

        self._prune_memory(store, self._current_bucket())
        if not store:
            # Every bucket has expired: drop the key so idle keys do not
            # accumulate in the shared store.
            del self._memory_store[mem_key]
            return 0

        return sum(store.values())
| |
|