| | import json |
| | import uuid |
| | from open_webui.utils.redis import get_redis_connection |
| | from open_webui.env import REDIS_KEY_PREFIX |
| | from typing import Optional, List, Tuple |
| | import pycrdt as Y |
| |
|
| |
|
| | class RedisLock: |
| | def __init__( |
| | self, |
| | redis_url, |
| | lock_name, |
| | timeout_secs, |
| | redis_sentinels=[], |
| | redis_cluster=False, |
| | ): |
| |
|
| | self.lock_name = lock_name |
| | self.lock_id = str(uuid.uuid4()) |
| | self.timeout_secs = timeout_secs |
| | self.lock_obtained = False |
| | self.redis = get_redis_connection( |
| | redis_url, |
| | redis_sentinels, |
| | redis_cluster=redis_cluster, |
| | decode_responses=True, |
| | ) |
| |
|
| | def aquire_lock(self): |
| | |
| | self.lock_obtained = self.redis.set( |
| | self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs |
| | ) |
| | return self.lock_obtained |
| |
|
| | def renew_lock(self): |
| | |
| | return self.redis.set( |
| | self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs |
| | ) |
| |
|
| | def release_lock(self): |
| | lock_value = self.redis.get(self.lock_name) |
| | if lock_value and lock_value == self.lock_id: |
| | self.redis.delete(self.lock_name) |
| |
|
| |
|
| | class RedisDict: |
| | def __init__(self, name, redis_url, redis_sentinels=[], redis_cluster=False): |
| | self.name = name |
| | self.redis = get_redis_connection( |
| | redis_url, |
| | redis_sentinels, |
| | redis_cluster=redis_cluster, |
| | decode_responses=True, |
| | ) |
| |
|
| | def __setitem__(self, key, value): |
| | serialized_value = json.dumps(value) |
| | self.redis.hset(self.name, key, serialized_value) |
| |
|
| | def __getitem__(self, key): |
| | value = self.redis.hget(self.name, key) |
| | if value is None: |
| | raise KeyError(key) |
| | return json.loads(value) |
| |
|
| | def __delitem__(self, key): |
| | result = self.redis.hdel(self.name, key) |
| | if result == 0: |
| | raise KeyError(key) |
| |
|
| | def __contains__(self, key): |
| | return self.redis.hexists(self.name, key) |
| |
|
| | def __len__(self): |
| | return self.redis.hlen(self.name) |
| |
|
| | def keys(self): |
| | return self.redis.hkeys(self.name) |
| |
|
| | def values(self): |
| | return [json.loads(v) for v in self.redis.hvals(self.name)] |
| |
|
| | def items(self): |
| | return [(k, json.loads(v)) for k, v in self.redis.hgetall(self.name).items()] |
| |
|
| | def set(self, mapping: dict): |
| | pipe = self.redis.pipeline() |
| |
|
| | pipe.delete(self.name) |
| | if mapping: |
| | pipe.hset(self.name, mapping={k: json.dumps(v) for k, v in mapping.items()}) |
| |
|
| | pipe.execute() |
| |
|
| | def get(self, key, default=None): |
| | try: |
| | return self[key] |
| | except KeyError: |
| | return default |
| |
|
| | def clear(self): |
| | self.redis.delete(self.name) |
| |
|
| | def update(self, other=None, **kwargs): |
| | if other is not None: |
| | for k, v in other.items() if hasattr(other, "items") else other: |
| | self[k] = v |
| | for k, v in kwargs.items(): |
| | self[k] = v |
| |
|
| | def setdefault(self, key, default=None): |
| | if key not in self: |
| | self[key] = default |
| | return self[key] |
| |
|
| |
|
| | class YdocManager: |
| | COMPACTION_THRESHOLD = 500 |
| |
|
| | def __init__( |
| | self, |
| | redis=None, |
| | redis_key_prefix: str = f"{REDIS_KEY_PREFIX}:ydoc:documents", |
| | ): |
| | self._updates = {} |
| | self._users = {} |
| | self._redis = redis |
| | self._redis_key_prefix = redis_key_prefix |
| |
|
| | async def append_to_updates(self, document_id: str, update: bytes): |
| | document_id = document_id.replace(":", "_") |
| | if self._redis: |
| | redis_key = f"{self._redis_key_prefix}:{document_id}:updates" |
| | await self._redis.rpush(redis_key, json.dumps(list(update))) |
| | list_len = await self._redis.llen(redis_key) |
| | if list_len >= self.COMPACTION_THRESHOLD: |
| | await self._compact_updates_redis(document_id) |
| | else: |
| | if document_id not in self._updates: |
| | self._updates[document_id] = [] |
| | self._updates[document_id].append(update) |
| | if len(self._updates[document_id]) >= self.COMPACTION_THRESHOLD: |
| | self._compact_updates_memory(document_id) |
| |
|
| | async def _compact_updates_redis(self, document_id: str): |
| | """Rolling compaction: squash oldest half into one snapshot.""" |
| | redis_key = f"{self._redis_key_prefix}:{document_id}:updates" |
| | all_updates = await self._redis.lrange(redis_key, 0, -1) |
| | if len(all_updates) <= 1: |
| | return |
| | mid = len(all_updates) // 2 |
| | ydoc = Y.Doc() |
| | for raw in all_updates[:mid]: |
| | ydoc.apply_update(bytes(json.loads(raw))) |
| | snapshot = json.dumps(list(ydoc.get_update())) |
| | pipe = self._redis.pipeline() |
| | pipe.delete(redis_key) |
| | pipe.rpush(redis_key, snapshot, *all_updates[mid:]) |
| | await pipe.execute() |
| |
|
| | def _compact_updates_memory(self, document_id: str): |
| | """Rolling compaction: squash oldest half into one snapshot.""" |
| | updates = self._updates.get(document_id, []) |
| | if len(updates) <= 1: |
| | return |
| | mid = len(updates) // 2 |
| | ydoc = Y.Doc() |
| | for update in updates[:mid]: |
| | ydoc.apply_update(bytes(update)) |
| | self._updates[document_id] = [ydoc.get_update()] + updates[mid:] |
| |
|
| | async def get_updates(self, document_id: str) -> List[bytes]: |
| | document_id = document_id.replace(":", "_") |
| |
|
| | if self._redis: |
| | redis_key = f"{self._redis_key_prefix}:{document_id}:updates" |
| | updates = await self._redis.lrange(redis_key, 0, -1) |
| | return [bytes(json.loads(update)) for update in updates] |
| | else: |
| | return self._updates.get(document_id, []) |
| |
|
| | async def document_exists(self, document_id: str) -> bool: |
| | document_id = document_id.replace(":", "_") |
| |
|
| | if self._redis: |
| | redis_key = f"{self._redis_key_prefix}:{document_id}:updates" |
| | return await self._redis.exists(redis_key) > 0 |
| | else: |
| | return document_id in self._updates |
| |
|
| | async def get_users(self, document_id: str) -> List[str]: |
| | document_id = document_id.replace(":", "_") |
| |
|
| | if self._redis: |
| | redis_key = f"{self._redis_key_prefix}:{document_id}:users" |
| | users = await self._redis.smembers(redis_key) |
| | return list(users) |
| | else: |
| | return self._users.get(document_id, []) |
| |
|
| | async def add_user(self, document_id: str, user_id: str): |
| | document_id = document_id.replace(":", "_") |
| |
|
| | if self._redis: |
| | redis_key = f"{self._redis_key_prefix}:{document_id}:users" |
| | await self._redis.sadd(redis_key, user_id) |
| | else: |
| | if document_id not in self._users: |
| | self._users[document_id] = set() |
| | self._users[document_id].add(user_id) |
| |
|
| | async def remove_user(self, document_id: str, user_id: str): |
| | document_id = document_id.replace(":", "_") |
| |
|
| | if self._redis: |
| | redis_key = f"{self._redis_key_prefix}:{document_id}:users" |
| | await self._redis.srem(redis_key, user_id) |
| | else: |
| | if document_id in self._users and user_id in self._users[document_id]: |
| | self._users[document_id].remove(user_id) |
| |
|
| | async def remove_user_from_all_documents(self, user_id: str): |
| | if self._redis: |
| | keys = [] |
| | async for key in self._redis.scan_iter( |
| | match=f"{self._redis_key_prefix}:*", count=100 |
| | ): |
| | keys.append(key) |
| | for key in keys: |
| | if key.endswith(":users"): |
| | await self._redis.srem(key, user_id) |
| |
|
| | document_id = key.split(":")[-2] |
| | if len(await self.get_users(document_id)) == 0: |
| | await self.clear_document(document_id) |
| |
|
| | else: |
| | for document_id in list(self._users.keys()): |
| | if user_id in self._users[document_id]: |
| | self._users[document_id].remove(user_id) |
| | if not self._users[document_id]: |
| | del self._users[document_id] |
| |
|
| | await self.clear_document(document_id) |
| |
|
| | async def clear_document(self, document_id: str): |
| | document_id = document_id.replace(":", "_") |
| |
|
| | if self._redis: |
| | redis_key = f"{self._redis_key_prefix}:{document_id}:updates" |
| | await self._redis.delete(redis_key) |
| | redis_users_key = f"{self._redis_key_prefix}:{document_id}:users" |
| | await self._redis.delete(redis_users_key) |
| | else: |
| | if document_id in self._updates: |
| | del self._updates[document_id] |
| | if document_id in self._users: |
| | del self._users[document_id] |
| |
|