# Scraped from a Hugging Face Space page (status: Paused; file size: 7,163 bytes; commit 55bd140).
import inspect
import logging
from urllib.parse import unquote, urlparse

import redis

from open_webui.env import REDIS_SENTINEL_MAX_RETRY_COUNT
# Logger named after this module.
log = logging.getLogger(name=__name__)

# Process-wide cache of the clients built by get_redis_connection(), keyed by
# a tuple of that function's connection arguments so that repeated calls with
# the same settings share a single client object.
_CONNECTION_CACHE = dict()
class SentinelRedisProxy:
    """Proxy that forwards attribute access to the current Sentinel master.

    Each attribute lookup, and each retry attempt of a wrapped call, asks
    ``sentinel.master_for(service, **kw)`` for a fresh master client, so a
    Sentinel fail-over is picked up. Callable attributes (except factory
    methods) are wrapped so that ``redis.exceptions.ConnectionError`` and
    ``redis.exceptions.ReadOnlyError`` trigger another attempt, up to
    ``REDIS_SENTINEL_MAX_RETRY_COUNT`` attempts in total (at least one).

    :param sentinel: object exposing ``master_for(service, **kw)`` (sync or
        ``redis.asyncio`` Sentinel; should match ``async_mode``).
    :param service: Sentinel service (master group) name.
    :param async_mode: when True, wrapped callables are ``async`` functions that
        await the underlying result if it is a coroutine.
    :param kw: extra keyword arguments forwarded to ``master_for``.
    """

    # Methods that hand back a helper object (pipeline, pub/sub, ...) rather
    # than running a command. They are returned unwrapped so callers receive
    # the real object instead of a retrying function/coroutine.
    _FACTORY_METHODS = frozenset(
        {"pipeline", "pubsub", "monitor", "client", "transaction"}
    )

    # The proxy's own instance attributes. If one of them reaches __getattr__,
    # __init__ has not run (e.g. during copy/unpickling). Delegating would call
    # self._master(), which reads self._sentinel again, and so on forever.
    _OWN_ATTRS = frozenset({"_sentinel", "_service", "_kw", "_async_mode"})

    def __init__(self, sentinel, service, *, async_mode: bool = True, **kw):
        self._sentinel = sentinel
        self._service = service
        self._kw = kw
        self._async_mode = async_mode

    def _master(self):
        """Return a client for the current master of ``self._service``."""
        return self._sentinel.master_for(self._service, **self._kw)

    @staticmethod
    def _should_retry(exc, attempt, attempts):
        """Log a failed attempt; return True if another attempt is allowed.

        :param exc: the retryable exception that was raised.
        :param attempt: zero-based index of the attempt that failed.
        :param attempts: total number of attempts allowed.
        """
        if attempt < attempts - 1:
            log.debug(
                "Redis sentinel fail-over (%s). Retry %s/%s",
                type(exc).__name__,
                attempt + 1,
                attempts,
            )
            return True
        log.error(
            "Redis operation failed after %s retries: %s",
            attempts,
            exc,
        )
        return False

    def __getattr__(self, item):
        """Delegate ``item`` to the current master client.

        Non-callable attributes and factory methods are returned as-is. Any
        other callable is wrapped with the fail-over retry loop.

        :raises AttributeError: if the master has no attribute ``item``, or if
            ``item`` is one of the proxy's own, not yet initialised attributes.
        """
        if item in self._OWN_ATTRS:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )

        orig_attr = getattr(self._master(), item)
        if not callable(orig_attr) or item in self._FACTORY_METHODS:
            return orig_attr

        # Always make at least one attempt. A non-positive setting would
        # otherwise skip the loop and every call would silently return None.
        attempts = max(1, REDIS_SENTINEL_MAX_RETRY_COUNT)
        retryable = (
            redis.exceptions.ConnectionError,
            redis.exceptions.ReadOnlyError,
        )

        if self._async_mode:

            async def _wrapped(*args, **kwargs):
                for attempt in range(attempts):
                    try:
                        # Re-resolve the master on every attempt so a retry
                        # after fail-over targets the newly promoted node.
                        method = getattr(self._master(), item)
                        result = method(*args, **kwargs)
                        if inspect.iscoroutine(result):
                            return await result
                        return result
                    except retryable as exc:
                        if not self._should_retry(exc, attempt, attempts):
                            # Bare raise keeps the original traceback without
                            # making the exception its own __cause__.
                            raise

            return _wrapped

        def _wrapped(*args, **kwargs):
            for attempt in range(attempts):
                try:
                    method = getattr(self._master(), item)
                    return method(*args, **kwargs)
                except retryable as exc:
                    if not self._should_retry(exc, attempt, attempts):
                        raise

        return _wrapped
def parse_redis_service_url(redis_url):
    """Split a ``redis://`` / ``rediss://`` URL into Sentinel connection settings.

    The URL's host part is interpreted as the Sentinel service (master group)
    name.

    :param redis_url: URL of the form
        ``redis[s]://[username][:password]@service[:port][/db]``.
    :returns: dict with keys ``username`` and ``password`` (percent-decoded,
        or None when absent or empty), ``service`` (default ``"mymaster"``),
        ``port`` (default 6379) and ``db`` (default 0).
    :raises ValueError: if the scheme is not ``redis``/``rediss``, or if the
        port or the db path is not a valid integer.
    """
    parsed_url = urlparse(redis_url)
    if parsed_url.scheme not in ("redis", "rediss"):
        raise ValueError("Invalid Redis URL scheme. Must be 'redis' or 'rediss'.")
    # Credentials containing reserved characters ('@', ':', '/', ...) must be
    # percent-encoded in a URL. urlparse does not decode them, so decode them
    # here, as redis-py's own from_url() does. Otherwise Sentinel would send
    # the encoded text and authentication would fail.
    username = unquote(parsed_url.username) if parsed_url.username else None
    password = unquote(parsed_url.password) if parsed_url.password else None
    return {
        "username": username,
        "password": password,
        # NOTE: urlparse lower-cases the hostname, so the service name is
        # effectively lower-cased here.
        "service": parsed_url.hostname or "mymaster",
        "port": parsed_url.port or 6379,
        "db": int(parsed_url.path.lstrip("/") or 0),
    }
def get_redis_connection(
    redis_url,
    redis_sentinels,
    redis_cluster=False,
    async_mode=False,
    decode_responses=True,
):
    """Return a Redis client for the given settings, reusing a cached one.

    The first matching option is used:

    1. Sentinel, when ``redis_sentinels`` is non-empty: a
       ``SentinelRedisProxy`` around a ``Sentinel`` built from
       ``parse_redis_service_url(redis_url)``.
    2. Redis Cluster, when ``redis_cluster`` is true.
    3. A plain client from ``redis_url``.

    If none applies, ``None`` is returned (and cached). Results are cached in
    ``_CONNECTION_CACHE`` per distinct argument combination.

    :param redis_url: ``redis://`` / ``rediss://`` URL. In Sentinel mode its
        host is the service name.
    :param redis_sentinels: sentinel addresses, presumably ``(host, port)``
        tuples (e.g. from ``get_sentinels_from_env``). Must be hashable
        elements, because they become part of the cache key.
    :param redis_cluster: use ``RedisCluster`` instead of ``Redis``.
    :param async_mode: build ``redis.asyncio`` clients instead of sync ones.
    :param decode_responses: forwarded to the client constructor.
    :raises ValueError: in cluster mode when ``redis_url`` is empty.
    """
    # redis_cluster is part of the key: otherwise a cluster request could be
    # served a cached single-node client for the same URL (or vice versa).
    cache_key = (
        redis_url,
        tuple(redis_sentinels) if redis_sentinels else (),
        redis_cluster,
        async_mode,
        decode_responses,
    )
    if cache_key in _CONNECTION_CACHE:
        return _CONNECTION_CACHE[cache_key]

    # Both packages expose the same layout (sentinel.Sentinel,
    # cluster.RedisCluster, Redis.from_url), so one code path serves both.
    if async_mode:
        import redis.asyncio as redis_module
    else:
        import redis as redis_module

    connection = None
    if redis_sentinels:
        redis_config = parse_redis_service_url(redis_url)
        sentinel = redis_module.sentinel.Sentinel(
            redis_sentinels,
            port=redis_config["port"],
            db=redis_config["db"],
            username=redis_config["username"],
            password=redis_config["password"],
            decode_responses=decode_responses,
        )
        connection = SentinelRedisProxy(
            sentinel,
            redis_config["service"],
            async_mode=async_mode,
        )
    elif redis_cluster:
        if not redis_url:
            raise ValueError("Redis URL must be provided for cluster mode.")
        # Cached like every other client instead of being rebuilt per call.
        connection = redis_module.cluster.RedisCluster.from_url(
            redis_url, decode_responses=decode_responses
        )
    elif redis_url:
        connection = redis_module.Redis.from_url(
            redis_url, decode_responses=decode_responses
        )

    _CONNECTION_CACHE[cache_key] = connection
    return connection
def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):
    """Build sentinel addresses from a comma-separated host list and a port.

    :param sentinel_hosts_env: e.g. ``"host1,host2"``. An empty value or None
        yields ``[]``. Whitespace around entries is ignored, and empty entries
        (from stray or trailing commas) are skipped.
    :param sentinel_port_env: port shared by every sentinel, converted with
        ``int()``.
    :returns: list of ``(host, port)`` tuples.
    :raises ValueError: if ``sentinel_port_env`` is not an integer string.
    :raises TypeError: if ``sentinel_port_env`` is None while hosts are given.
    """
    if not sentinel_hosts_env:
        return []
    sentinel_port = int(sentinel_port_env)
    # Without strip/filter, "a, b" would yield the hostname " b" and "a,"
    # would yield "", and neither resolves.
    return [
        (host.strip(), sentinel_port)
        for host in sentinel_hosts_env.split(",")
        if host.strip()
    ]
def get_sentinel_url_from_env(redis_url, sentinel_hosts_env, sentinel_port_env):
    """Build a ``redis+sentinel://`` URL from a Redis URL and sentinel settings.

    The result has the form
    ``redis+sentinel://[user:password@]host1:port,host2:port/<db>/<service>``.
    Here ``db`` and ``service`` come from ``parse_redis_service_url``, and
    ``sentinel_port_env`` is used as given for every host.

    :param redis_url: ``redis://`` / ``rediss://`` URL whose host is the
        Sentinel service name.
    :param sentinel_hosts_env: comma-separated sentinel hosts. Whitespace
        around entries is ignored and empty entries are skipped.
    :param sentinel_port_env: port appended to every host.
    :raises ValueError: propagated from ``parse_redis_service_url`` for an
        invalid scheme, port or db.
    """
    redis_config = parse_redis_service_url(redis_url)
    # Read the credentials straight from the URL. urlparse leaves them
    # percent-encoded exactly as written, which is what must be embedded in
    # the new URL for passwords containing '@', ':' or '/'.
    parsed_url = urlparse(redis_url)
    username = parsed_url.username or ""
    password = parsed_url.password or ""
    auth_part = f"{username}:{password}@" if username or password else ""
    hosts_part = ",".join(
        f"{host.strip()}:{sentinel_port_env}"
        for host in sentinel_hosts_env.split(",")
        if host.strip()
    )
    return f"redis+sentinel://{auth_part}{hosts_part}/{redis_config['db']}/{redis_config['service']}"
|