Spaces:
Build error
Build error
| import inspect | |
| from urllib.parse import urlparse | |
| import asyncio | |
| import time | |
| import logging | |
| import redis | |
| from open_webui.env import ( | |
| REDIS_CLUSTER, | |
| REDIS_HEALTH_CHECK_INTERVAL, | |
| REDIS_SOCKET_CONNECT_TIMEOUT, | |
| REDIS_SOCKET_KEEPALIVE, | |
| REDIS_SENTINEL_HOSTS, | |
| REDIS_SENTINEL_MAX_RETRY_COUNT, | |
| REDIS_SENTINEL_PORT, | |
| REDIS_URL, | |
| REDIS_RECONNECT_DELAY, | |
| ) | |
# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Alias of REDIS_SENTINEL_MAX_RETRY_COUNT. It is not referenced by the code
# visible in this file -- presumably kept for external importers (TODO confirm).
MAX_RETRY_COUNT = REDIS_SENTINEL_MAX_RETRY_COUNT

# Cache of connection objects built by get_redis_connection(), keyed by the
# parameters that function was called with, so that repeated calls with the
# same settings reuse the same connection object instead of building a new one.
_CONNECTION_CACHE = {}
class SentinelRedisProxy:
    """Proxy that resolves the Sentinel master on every attribute access.

    Non-callable attributes and factory methods (``pipeline``, ``pubsub``,
    ``monitor``, ``client``, ``transaction``) are returned directly from the
    current master. Every other callable is wrapped so that a
    ``redis.exceptions.ConnectionError`` or ``redis.exceptions.ReadOnlyError``
    triggers a fresh master lookup and a retry, for at most
    ``REDIS_SENTINEL_MAX_RETRY_COUNT`` attempts, optionally pausing
    ``REDIS_RECONNECT_DELAY`` milliseconds between attempts.

    Args:
        sentinel: Object exposing ``master_for(service, **kw)``.
        service: Service name passed to ``master_for``.
        async_mode: When true, wrapped callables are coroutine functions
            (or async-generator factories); otherwise plain functions.
        **kw: Extra keyword arguments forwarded to ``master_for``.
    """

    # Methods whose return value is handed back unwrapped (no retry wrapper).
    _FACTORY_METHODS = frozenset(
        {'pipeline', 'pubsub', 'monitor', 'client', 'transaction'}
    )
    # Attributes owned by the proxy itself; never looked up on the master.
    _OWN_ATTRS = frozenset({'_sentinel', '_service', '_kw', '_async_mode'})

    def __init__(self, sentinel, service, *, async_mode: bool = True, **kw):
        self._sentinel = sentinel
        self._service = service
        self._kw = kw
        self._async_mode = async_mode

    def _master(self):
        """Return the client for the current master of the service."""
        return self._sentinel.master_for(self._service, **self._kw)

    @staticmethod
    def _should_retry(exc, attempt):
        """Log a failed attempt and report whether another one is allowed.

        Args:
            exc: The exception raised by the failed attempt.
            attempt: Zero-based index of the attempt that failed.

        Returns:
            True if at least one attempt remains, False otherwise.
        """
        if attempt < REDIS_SENTINEL_MAX_RETRY_COUNT - 1:
            log.debug(
                'Redis sentinel fail-over (%s). Retry %s/%s',
                type(exc).__name__,
                attempt + 1,
                REDIS_SENTINEL_MAX_RETRY_COUNT,
            )
            return True
        log.error(
            'Redis operation failed after %s retries: %s',
            REDIS_SENTINEL_MAX_RETRY_COUNT,
            exc,
        )
        return False

    def _wrap_async_iter(self, item):
        """Build a retrying replacement for the async-generator method *item*."""

        def _wrapped_iter(*args, **kwargs):
            async def _iter():
                # NOTE(review): a retry restarts the iteration from the
                # beginning, so values yielded before the failure are
                # yielded again.
                for attempt in range(REDIS_SENTINEL_MAX_RETRY_COUNT):
                    try:
                        method = getattr(self._master(), item)
                        async for value in method(*args, **kwargs):
                            yield value
                        return
                    except (
                        redis.exceptions.ConnectionError,
                        redis.exceptions.ReadOnlyError,
                    ) as e:
                        if not self._should_retry(e, attempt):
                            raise
                        if REDIS_RECONNECT_DELAY:
                            # Must not block the event loop while waiting.
                            await asyncio.sleep(REDIS_RECONNECT_DELAY / 1000)

            return _iter()

        return _wrapped_iter

    def _wrap_async(self, item):
        """Build a retrying coroutine function for the method *item*."""

        async def _wrapped(*args, **kwargs):
            for attempt in range(REDIS_SENTINEL_MAX_RETRY_COUNT):
                try:
                    method = getattr(self._master(), item)
                    result = method(*args, **kwargs)
                    # Some methods return plain values even on async clients.
                    if inspect.iscoroutine(result):
                        return await result
                    return result
                except (
                    redis.exceptions.ConnectionError,
                    redis.exceptions.ReadOnlyError,
                ) as e:
                    if not self._should_retry(e, attempt):
                        raise
                    if REDIS_RECONNECT_DELAY:
                        await asyncio.sleep(REDIS_RECONNECT_DELAY / 1000)

        return _wrapped

    def _wrap_sync(self, item):
        """Build a retrying plain function for the method *item*."""

        def _wrapped(*args, **kwargs):
            for attempt in range(REDIS_SENTINEL_MAX_RETRY_COUNT):
                try:
                    method = getattr(self._master(), item)
                    return method(*args, **kwargs)
                except (
                    redis.exceptions.ConnectionError,
                    redis.exceptions.ReadOnlyError,
                ) as e:
                    if not self._should_retry(e, attempt):
                        raise
                    if REDIS_RECONNECT_DELAY:
                        time.sleep(REDIS_RECONNECT_DELAY / 1000)

        return _wrapped

    def __getattr__(self, item):
        """Forward *item* to the current master, wrapping callables with retry.

        Raises:
            AttributeError: If *item* is one of the proxy's own attributes
                and is missing, or if the master has no such attribute.
        """
        # __getattr__ only runs when normal lookup fails. If one of the
        # proxy's own attributes is missing (e.g. the instance was created
        # without __init__), asking the master would call self._master(),
        # which reads self._sentinel and re-enters this method endlessly.
        if item in self._OWN_ATTRS:
            raise AttributeError(item)

        orig_attr = getattr(self._master(), item)
        if not callable(orig_attr):
            return orig_attr
        if item in self._FACTORY_METHODS:
            return orig_attr

        if not self._async_mode:
            return self._wrap_sync(item)
        if inspect.isasyncgenfunction(orig_attr):
            return self._wrap_async_iter(item)
        return self._wrap_async(item)
def parse_redis_service_url(redis_url):
    """Split a ``redis://`` / ``rediss://`` URL into its connection settings.

    The URL's host component is reported under the ``service`` key.

    Args:
        redis_url: URL whose scheme must be ``redis`` or ``rediss``.

    Returns:
        A dict with keys ``username``, ``password``, ``service``, ``port``
        and ``db``. Missing credentials become ``None``; a missing host,
        port or database falls back to ``'mymaster'``, ``6379`` and ``0``.

    Raises:
        ValueError: If the scheme is neither ``redis`` nor ``rediss``, or
            the port or database component is not a valid integer.
    """
    parts = urlparse(redis_url)
    if parts.scheme not in ('redis', 'rediss'):
        raise ValueError("Invalid Redis URL scheme. Must be 'redis' or 'rediss'.")

    username = parts.username or None
    password = parts.password or None
    service = parts.hostname or 'mymaster'
    port = parts.port or 6379

    # The database index is the URL path without its leading slash(es).
    db_text = parts.path.lstrip('/')
    db = int(db_text) if db_text else 0

    return dict(
        username=username,
        password=password,
        service=service,
        port=port,
        db=db,
    )
def get_redis_client(async_mode=False):
    """Build a Redis client from the environment-derived settings.

    This is a best-effort helper: any exception raised while resolving the
    sentinel list or creating the connection is logged at debug level and
    turned into a ``None`` return value.

    Args:
        async_mode: Forwarded to ``get_redis_connection``.

    Returns:
        Whatever ``get_redis_connection`` returns, or ``None`` on failure.
    """
    try:
        sentinels = get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT)
        client = get_redis_connection(
            redis_url=REDIS_URL,
            redis_sentinels=sentinels,
            redis_cluster=REDIS_CLUSTER,
            async_mode=async_mode,
        )
    except Exception as e:
        log.debug(f'Failed to get Redis client: {e}')
        return None
    return client
def get_redis_connection(
    redis_url,
    redis_sentinels,
    redis_cluster=False,
    async_mode=False,
    decode_responses=True,
):
    """Return a (cached) Redis connection for the given settings.

    The connection flavour is chosen in this order: Sentinel (when
    ``redis_sentinels`` is non-empty), cluster (when ``redis_cluster`` is
    true), plain URL (when ``redis_url`` is set). Created connections are
    stored in ``_CONNECTION_CACHE`` and reused by later calls that pass the
    same arguments.

    Args:
        redis_url: Redis URL. In Sentinel mode it is parsed with
            ``parse_redis_service_url`` for service name, port, db and
            credentials.
        redis_sentinels: Sequence of ``(host, port)`` sentinel addresses,
            or a falsy value when Sentinel is not used.
        redis_cluster: Use a cluster client when no sentinels are given.
        async_mode: Use ``redis.asyncio`` instead of the synchronous client.
        decode_responses: Forwarded to the underlying client.

    Returns:
        A ``SentinelRedisProxy``, a cluster client, a plain client, or
        ``None`` when none of the three modes is configured.

    Raises:
        ValueError: In cluster mode without a URL, or in Sentinel mode when
            ``redis_url`` cannot be parsed.
    """
    # redis_cluster is part of the key so that a cluster client and a plain
    # client for the same URL never share a cache slot.
    cache_key = (
        redis_url,
        tuple(redis_sentinels) if redis_sentinels else (),
        bool(redis_cluster),
        async_mode,
        decode_responses,
    )
    cached = _CONNECTION_CACHE.get(cache_key)
    if cached is not None:
        return cached

    # Pick the client flavour once; both modules are used through the same
    # attribute names below (sentinel.Sentinel, cluster.RedisCluster, Redis).
    if async_mode:
        import redis.asyncio as redis_module
    else:
        import redis as redis_module

    # Optional settings are only passed when configured, so the client
    # library's own defaults apply otherwise.
    client_kwargs = {'decode_responses': decode_responses}
    if REDIS_SOCKET_CONNECT_TIMEOUT is not None:
        client_kwargs['socket_connect_timeout'] = REDIS_SOCKET_CONNECT_TIMEOUT
    if REDIS_SOCKET_KEEPALIVE:
        client_kwargs['socket_keepalive'] = True
    if REDIS_HEALTH_CHECK_INTERVAL:
        client_kwargs['health_check_interval'] = REDIS_HEALTH_CHECK_INTERVAL

    connection = None
    if redis_sentinels:
        redis_config = parse_redis_service_url(redis_url)
        sentinel = redis_module.sentinel.Sentinel(
            redis_sentinels,
            port=redis_config['port'],
            db=redis_config['db'],
            username=redis_config['username'],
            password=redis_config['password'],
            **client_kwargs,
        )
        connection = SentinelRedisProxy(
            sentinel,
            redis_config['service'],
            async_mode=async_mode,
        )
    elif redis_cluster:
        if not redis_url:
            raise ValueError('Redis URL must be provided for cluster mode.')
        # Assigned (not returned) so the cluster client is cached as well.
        connection = redis_module.cluster.RedisCluster.from_url(
            redis_url,
            **client_kwargs,
        )
    elif redis_url:
        connection = redis_module.Redis.from_url(redis_url, **client_kwargs)

    # Do not cache "nothing configured"; only real connections are stored.
    if connection is not None:
        _CONNECTION_CACHE[cache_key] = connection
    return connection
def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):
    """Turn a comma-separated host list into ``(host, port)`` tuples.

    Whitespace around each host is stripped and empty entries (for example
    from a trailing comma) are ignored, so ``"a, b,"`` yields hosts ``a``
    and ``b``.

    Args:
        sentinel_hosts_env: Comma-separated sentinel host names, or a falsy
            value when Sentinel is not configured.
        sentinel_port_env: Port shared by all sentinels; converted with
            ``int`` only when at least one host is present.

    Returns:
        A list of ``(host, port)`` tuples; empty when no hosts are given.

    Raises:
        ValueError: If hosts are present and the port is not an integer.
    """
    if not sentinel_hosts_env:
        return []

    stripped = (host.strip() for host in sentinel_hosts_env.split(','))
    sentinel_hosts = [host for host in stripped if host]
    if not sentinel_hosts:
        return []

    sentinel_port = int(sentinel_port_env)
    return [(host, sentinel_port) for host in sentinel_hosts]
def get_sentinel_url_from_env(redis_url, sentinel_hosts_env, sentinel_port_env):
    """Build a ``redis+sentinel://`` URL from a Redis URL and a host list.

    Credentials, database index and service name are taken from
    ``redis_url`` (via ``parse_redis_service_url``); every sentinel host is
    paired with ``sentinel_port_env``. Whitespace around host names is
    stripped and empty entries are skipped so the result contains no stray
    spaces or dangling ``:port`` items.

    Args:
        redis_url: ``redis://`` or ``rediss://`` URL describing the service.
        sentinel_hosts_env: Comma-separated sentinel host names.
        sentinel_port_env: Port appended to every sentinel host.

    Returns:
        A string of the form
        ``redis+sentinel://[user:password@]host:port[,host:port...]/db/service``.

    Raises:
        ValueError: If ``redis_url`` is rejected by
            ``parse_redis_service_url``.
    """
    redis_config = parse_redis_service_url(redis_url)
    username = redis_config['username'] or ''
    password = redis_config['password'] or ''
    auth_part = f'{username}:{password}@' if username or password else ''

    hosts = (host.strip() for host in sentinel_hosts_env.split(','))
    hosts_part = ','.join(f'{host}:{sentinel_port_env}' for host in hosts if host)
    return f'redis+sentinel://{auth_part}{hosts_part}/{redis_config["db"]}/{redis_config["service"]}'