File size: 4,115 Bytes
cfb0fa4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 | import time
from typing import Optional, Dict
from open_webui.env import REDIS_KEY_PREFIX
class RateLimiter:
"""
General-purpose rate limiter using Redis with a rolling window strategy.
Falls back to in-memory storage if Redis is not available.
"""
# In-memory fallback storage
_memory_store: Dict[str, Dict[int, int]] = {}
def __init__(
self,
redis_client,
limit: int,
window: int,
bucket_size: int = 60,
enabled: bool = True,
):
"""
:param redis_client: Redis client instance or None
:param limit: Max allowed events in the window
:param window: Time window in seconds
:param bucket_size: Bucket resolution
:param enabled: Turn on/off rate limiting globally
"""
self.r = redis_client
self.limit = limit
self.window = window
self.bucket_size = bucket_size
self.num_buckets = window // bucket_size
self.enabled = enabled
def _bucket_key(self, key: str, bucket_index: int) -> str:
return f"{REDIS_KEY_PREFIX}:ratelimit:{key.lower()}:{bucket_index}"
def _current_bucket(self) -> int:
return int(time.time()) // self.bucket_size
def _redis_available(self) -> bool:
return self.r is not None
def is_limited(self, key: str) -> bool:
"""
Main rate-limit check.
Gracefully handles missing or failing Redis.
"""
if not self.enabled:
return False
if self._redis_available():
try:
return self._is_limited_redis(key)
except Exception:
return self._is_limited_memory(key)
else:
return self._is_limited_memory(key)
def get_count(self, key: str) -> int:
if not self.enabled:
return 0
if self._redis_available():
try:
return self._get_count_redis(key)
except Exception:
return self._get_count_memory(key)
else:
return self._get_count_memory(key)
def remaining(self, key: str) -> int:
used = self.get_count(key)
return max(0, self.limit - used)
def _is_limited_redis(self, key: str) -> bool:
now_bucket = self._current_bucket()
bucket_key = self._bucket_key(key, now_bucket)
attempts = self.r.incr(bucket_key)
if attempts == 1:
self.r.expire(bucket_key, self.window + self.bucket_size)
# Collect buckets
buckets = [
self._bucket_key(key, now_bucket - i) for i in range(self.num_buckets + 1)
]
counts = self.r.mget(buckets)
total = sum(int(c) for c in counts if c)
return total > self.limit
def _get_count_redis(self, key: str) -> int:
now_bucket = self._current_bucket()
buckets = [
self._bucket_key(key, now_bucket - i) for i in range(self.num_buckets + 1)
]
counts = self.r.mget(buckets)
return sum(int(c) for c in counts if c)
def _is_limited_memory(self, key: str) -> bool:
now_bucket = self._current_bucket()
# Init storage
if key not in self._memory_store:
self._memory_store[key] = {}
store = self._memory_store[key]
# Increment bucket
store[now_bucket] = store.get(now_bucket, 0) + 1
# Drop expired buckets
min_bucket = now_bucket - self.num_buckets
expired = [b for b in store if b < min_bucket]
for b in expired:
del store[b]
# Count totals
total = sum(store.values())
return total > self.limit
def _get_count_memory(self, key: str) -> int:
now_bucket = self._current_bucket()
if key not in self._memory_store:
return 0
store = self._memory_store[key]
min_bucket = now_bucket - self.num_buckets
# Remove expired
expired = [b for b in store if b < min_bucket]
for b in expired:
del store[b]
return sum(store.values())
|