| """Redis-backed distributed data structures for WebSocket state management.""" |
|
|
| from __future__ import annotations |
|
|
| import hashlib |
| import json |
| import uuid |
|
|
| import pycrdt as Y |
| from open_webui.utils.redis import get_redis_connection |
| from open_webui.env import REDIS_KEY_PREFIX |
|
|
| YDOC_KEY_PREFIX = f'{REDIS_KEY_PREFIX}:ydoc:documents' |
|
|
|
|
| class RedisLock: |
| """Distributed lock backed by a Redis SET with NX/EX semantics.""" |
|
|
| def __init__( |
| self, |
| redis_url, |
| lock_name, |
| timeout_secs, |
| redis_sentinels=[], |
| redis_cluster=False, |
| ): |
| self.lock_name = lock_name |
| self.lock_id = str(uuid.uuid4()) |
| self.timeout_secs = timeout_secs |
| self.lock_obtained = False |
| self.redis = get_redis_connection( |
| redis_url, |
| redis_sentinels, |
| redis_cluster=redis_cluster, |
| decode_responses=True, |
| ) |
|
|
| def aquire_lock(self): |
| |
| self.lock_obtained = self.redis.set(self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs) |
| return self.lock_obtained |
|
|
| def renew_lock(self): |
| |
| return self.redis.set(self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs) |
|
|
| def release_lock(self): |
| lock_value = self.redis.get(self.lock_name) |
| if lock_value and lock_value == self.lock_id: |
| self.redis.delete(self.lock_name) |
|
|
|
|
| class RedisDict: |
| def __init__(self, name, redis_url, redis_sentinels=[], redis_cluster=False): |
| self.name = name |
| |
| |
| |
| self._last_signature: str | None = None |
| self.redis = get_redis_connection( |
| redis_url, |
| redis_sentinels, |
| redis_cluster=redis_cluster, |
| decode_responses=True, |
| ) |
|
|
| def __setitem__(self, key, value): |
| serialized_value = json.dumps(value) |
| self.redis.hset(self.name, key, serialized_value) |
|
|
| def __getitem__(self, key): |
| value = self.redis.hget(self.name, key) |
| if value is None: |
| raise KeyError(key) |
| return json.loads(value) |
|
|
| def __delitem__(self, key): |
| result = self.redis.hdel(self.name, key) |
| if result == 0: |
| raise KeyError(key) |
|
|
| def __contains__(self, key): |
| return self.redis.hexists(self.name, key) |
|
|
| def __len__(self): |
| return self.redis.hlen(self.name) |
|
|
| def keys(self): |
| return self.redis.hkeys(self.name) |
|
|
| def values(self): |
| return [json.loads(v) for v in self.redis.hvals(self.name)] |
|
|
| def items(self): |
| return [(k, json.loads(v)) for k, v in self.redis.hgetall(self.name).items()] |
|
|
| def set(self, mapping: dict): |
| if not mapping: |
| self.redis.delete(self.name) |
| self._last_signature = None |
| return |
|
|
| |
| serialized = {k: json.dumps(v) for k, v in mapping.items()} |
|
|
| |
| |
| |
| |
| signature = hashlib.sha256(json.dumps(serialized, sort_keys=True).encode()).hexdigest() |
| if signature == self._last_signature: |
| return |
|
|
| |
| |
| existing_keys = set(self.redis.hkeys(self.name)) |
| new_keys = set(mapping.keys()) |
| keys_to_remove = existing_keys - new_keys |
|
|
| |
| |
| |
| self.redis.hset(self.name, mapping=serialized) |
| if keys_to_remove: |
| self.redis.hdel(self.name, *keys_to_remove) |
|
|
| self._last_signature = signature |
|
|
| def get(self, key, default=None): |
| try: |
| return self[key] |
| except KeyError: |
| return default |
|
|
| def clear(self): |
| self.redis.delete(self.name) |
| self._last_signature = None |
|
|
| def update(self, other=None, **kwargs): |
| if other is not None: |
| for k, v in other.items() if hasattr(other, 'items') else other: |
| self[k] = v |
| for k, v in kwargs.items(): |
| self[k] = v |
|
|
| def setdefault(self, key, default=None): |
| if key not in self: |
| self[key] = default |
| return self[key] |
|
|
|
|
| class YdocManager: |
| COMPACTION_THRESHOLD = 500 |
|
|
| def __init__( |
| self, |
| redis=None, |
| redis_key_prefix: str = YDOC_KEY_PREFIX, |
| ): |
| self._updates = {} |
| self._users = {} |
| self._redis = redis |
| self._redis_key_prefix = redis_key_prefix |
|
|
| async def append_to_updates(self, document_id: str, update: bytes): |
| document_id = document_id.replace(':', '_') |
| if self._redis: |
| redis_key = f'{self._redis_key_prefix}:{document_id}:updates' |
| await self._redis.rpush(redis_key, json.dumps(list(update))) |
| list_len = await self._redis.llen(redis_key) |
| if list_len >= self.COMPACTION_THRESHOLD: |
| await self._compact_updates_redis(document_id) |
| else: |
| if document_id not in self._updates: |
| self._updates[document_id] = [] |
| self._updates[document_id].append(update) |
| if len(self._updates[document_id]) >= self.COMPACTION_THRESHOLD: |
| self._compact_updates_memory(document_id) |
|
|
| async def _compact_updates_redis(self, document_id: str): |
| """Rolling compaction: squash oldest half into one snapshot.""" |
| redis_key = f'{self._redis_key_prefix}:{document_id}:updates' |
| all_updates = await self._redis.lrange(redis_key, 0, -1) |
| if len(all_updates) <= 1: |
| return |
| mid = len(all_updates) // 2 |
| ydoc = Y.Doc() |
| for raw in all_updates[:mid]: |
| ydoc.apply_update(bytes(json.loads(raw))) |
| snapshot = json.dumps(list(ydoc.get_update())) |
| pipe = self._redis.pipeline() |
| pipe.delete(redis_key) |
| pipe.rpush(redis_key, snapshot, *all_updates[mid:]) |
| await pipe.execute() |
|
|
| def _compact_updates_memory(self, document_id: str): |
| """Rolling compaction: squash oldest half into one snapshot.""" |
| updates = self._updates.get(document_id, []) |
| if len(updates) <= 1: |
| return |
| mid = len(updates) // 2 |
| ydoc = Y.Doc() |
| for update in updates[:mid]: |
| ydoc.apply_update(bytes(update)) |
| self._updates[document_id] = [ydoc.get_update()] + updates[mid:] |
|
|
| async def get_updates(self, document_id: str) -> list[bytes]: |
| document_id = document_id.replace(':', '_') |
|
|
| if self._redis: |
| redis_key = f'{self._redis_key_prefix}:{document_id}:updates' |
| updates = await self._redis.lrange(redis_key, 0, -1) |
| return [bytes(json.loads(update)) for update in updates] |
| else: |
| return self._updates.get(document_id, []) |
|
|
| async def document_exists(self, document_id: str) -> bool: |
| document_id = document_id.replace(':', '_') |
|
|
| if self._redis: |
| redis_key = f'{self._redis_key_prefix}:{document_id}:updates' |
| return await self._redis.exists(redis_key) > 0 |
| else: |
| return document_id in self._updates |
|
|
| async def get_users(self, document_id: str) -> list[str]: |
| document_id = document_id.replace(':', '_') |
|
|
| if self._redis: |
| redis_key = f'{self._redis_key_prefix}:{document_id}:users' |
| users = await self._redis.smembers(redis_key) |
| return list(users) |
| else: |
| return self._users.get(document_id, []) |
|
|
| async def add_user(self, document_id: str, user_id: str): |
| document_id = document_id.replace(':', '_') |
|
|
| if self._redis: |
| redis_key = f'{self._redis_key_prefix}:{document_id}:users' |
| await self._redis.sadd(redis_key, user_id) |
| |
| |
| |
| session_key = f'{self._redis_key_prefix}:session:{user_id}:documents' |
| await self._redis.sadd(session_key, document_id) |
| else: |
| if document_id not in self._users: |
| self._users[document_id] = set() |
| self._users[document_id].add(user_id) |
|
|
| async def remove_user(self, document_id: str, user_id: str): |
| document_id = document_id.replace(':', '_') |
|
|
| if self._redis: |
| redis_key = f'{self._redis_key_prefix}:{document_id}:users' |
| await self._redis.srem(redis_key, user_id) |
| |
| session_key = f'{self._redis_key_prefix}:session:{user_id}:documents' |
| await self._redis.srem(session_key, document_id) |
| else: |
| if document_id in self._users and user_id in self._users[document_id]: |
| self._users[document_id].remove(user_id) |
|
|
| async def remove_user_from_all_documents(self, user_id: str): |
| if self._redis: |
| |
| |
| |
| |
| session_key = f'{self._redis_key_prefix}:session:{user_id}:documents' |
| document_ids = await self._redis.smembers(session_key) |
|
|
| for document_id in document_ids: |
| users_key = f'{self._redis_key_prefix}:{document_id}:users' |
| await self._redis.srem(users_key, user_id) |
|
|
| if len(await self.get_users(document_id)) == 0: |
| await self.clear_document(document_id) |
|
|
| |
| await self._redis.delete(session_key) |
|
|
| else: |
| for document_id in list(self._users.keys()): |
| if user_id in self._users[document_id]: |
| self._users[document_id].remove(user_id) |
| if not self._users[document_id]: |
| del self._users[document_id] |
|
|
| await self.clear_document(document_id) |
|
|
| async def clear_document(self, document_id: str): |
| document_id = document_id.replace(':', '_') |
|
|
| if self._redis: |
| redis_key = f'{self._redis_key_prefix}:{document_id}:updates' |
| await self._redis.delete(redis_key) |
| redis_users_key = f'{self._redis_key_prefix}:{document_id}:users' |
| await self._redis.delete(redis_users_key) |
| else: |
| if document_id in self._updates: |
| del self._updates[document_id] |
| if document_id in self._users: |
| del self._users[document_id] |
|
|