| | import logging |
| | from typing import Optional, Tuple, List, Dict, Any |
| | from urllib.parse import urlparse |
| |
|
| | import grpc |
| | from open_webui.config import ( |
| | QDRANT_API_KEY, |
| | QDRANT_GRPC_PORT, |
| | QDRANT_ON_DISK, |
| | QDRANT_PREFER_GRPC, |
| | QDRANT_URI, |
| | QDRANT_COLLECTION_PREFIX, |
| | QDRANT_TIMEOUT, |
| | QDRANT_HNSW_M, |
| | ) |
| | from open_webui.retrieval.vector.main import ( |
| | GetResult, |
| | SearchResult, |
| | VectorDBBase, |
| | VectorItem, |
| | ) |
| | from qdrant_client import QdrantClient as Qclient |
| | from qdrant_client.http.exceptions import UnexpectedResponse |
| | from qdrant_client.http.models import PointStruct |
| | from qdrant_client.models import models |
| |
|
| | NO_LIMIT = 999999999 |
| | TENANT_ID_FIELD = "tenant_id" |
| | DEFAULT_DIMENSION = 384 |
| |
|
| | log = logging.getLogger(__name__) |
| |
|
| |
|
| | def _tenant_filter(tenant_id: str) -> models.FieldCondition: |
| | return models.FieldCondition( |
| | key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id) |
| | ) |
| |
|
| |
|
| | def _metadata_filter(key: str, value: Any) -> models.FieldCondition: |
| | return models.FieldCondition( |
| | key=f"metadata.{key}", match=models.MatchValue(value=value) |
| | ) |
| |
|
| |
|
| | class QdrantClient(VectorDBBase): |
| | def __init__(self): |
| | self.collection_prefix = QDRANT_COLLECTION_PREFIX |
| | self.QDRANT_URI = QDRANT_URI |
| | self.QDRANT_API_KEY = QDRANT_API_KEY |
| | self.QDRANT_ON_DISK = QDRANT_ON_DISK |
| | self.PREFER_GRPC = QDRANT_PREFER_GRPC |
| | self.GRPC_PORT = QDRANT_GRPC_PORT |
| | self.QDRANT_TIMEOUT = QDRANT_TIMEOUT |
| | self.QDRANT_HNSW_M = QDRANT_HNSW_M |
| |
|
| | if not self.QDRANT_URI: |
| | raise ValueError( |
| | "QDRANT_URI is not set. Please configure it in the environment variables." |
| | ) |
| |
|
| | |
| | parsed = urlparse(self.QDRANT_URI) |
| | host = parsed.hostname or self.QDRANT_URI |
| | http_port = parsed.port or 6333 |
| |
|
| | self.client = ( |
| | Qclient( |
| | host=host, |
| | port=http_port, |
| | grpc_port=self.GRPC_PORT, |
| | prefer_grpc=self.PREFER_GRPC, |
| | api_key=self.QDRANT_API_KEY, |
| | timeout=self.QDRANT_TIMEOUT, |
| | ) |
| | if self.PREFER_GRPC |
| | else Qclient( |
| | url=self.QDRANT_URI, |
| | api_key=self.QDRANT_API_KEY, |
| | timeout=self.QDRANT_TIMEOUT, |
| | ) |
| | ) |
| |
|
| | |
| | self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories" |
| | self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge" |
| | self.FILE_COLLECTION = f"{self.collection_prefix}_files" |
| | self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search" |
| | self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based" |
| |
|
| | def _result_to_get_result(self, points) -> GetResult: |
| | ids, documents, metadatas = [], [], [] |
| | for point in points: |
| | payload = point.payload |
| | ids.append(point.id) |
| | documents.append(payload["text"]) |
| | metadatas.append(payload["metadata"]) |
| | return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas]) |
| |
|
| | def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]: |
| | """ |
| | Maps the traditional collection name to multi-tenant collection and tenant ID. |
| | |
| | Returns: |
| | tuple: (collection_name, tenant_id) |
| | |
| | WARNING: This mapping relies on current Open WebUI naming conventions for |
| | collection names. If Open WebUI changes how it generates collection names |
| | (e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash |
| | formats), this mapping will break and route data to incorrect collections. |
| | POTENTIALLY CAUSING HUGE DATA CORRUPTION, DATA CONSISTENCY ISSUES AND INCORRECT |
| | DATA MAPPING INSIDE THE DATABASE. |
| | """ |
| | |
| | tenant_id = collection_name |
| |
|
| | if collection_name.startswith("user-memory-"): |
| | return self.MEMORY_COLLECTION, tenant_id |
| |
|
| | |
| | elif collection_name.startswith("file-"): |
| | return self.FILE_COLLECTION, tenant_id |
| |
|
| | |
| | elif collection_name.startswith("web-search-"): |
| | return self.WEB_SEARCH_COLLECTION, tenant_id |
| |
|
| | |
| | elif len(collection_name) == 63 and all( |
| | c in "0123456789abcdef" for c in collection_name |
| | ): |
| | return self.HASH_BASED_COLLECTION, tenant_id |
| |
|
| | else: |
| | return self.KNOWLEDGE_COLLECTION, tenant_id |
| |
|
| | def _create_multi_tenant_collection( |
| | self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION |
| | ): |
| | """ |
| | Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields. |
| | """ |
| | self.client.create_collection( |
| | collection_name=mt_collection_name, |
| | vectors_config=models.VectorParams( |
| | size=dimension, |
| | distance=models.Distance.COSINE, |
| | on_disk=self.QDRANT_ON_DISK, |
| | ), |
| | |
| | |
| | hnsw_config=models.HnswConfigDiff( |
| | payload_m=self.QDRANT_HNSW_M, |
| | m=0, |
| | ), |
| | ) |
| | log.info( |
| | f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!" |
| | ) |
| |
|
| | self.client.create_payload_index( |
| | collection_name=mt_collection_name, |
| | field_name=TENANT_ID_FIELD, |
| | field_schema=models.KeywordIndexParams( |
| | type=models.KeywordIndexType.KEYWORD, |
| | is_tenant=True, |
| | on_disk=self.QDRANT_ON_DISK, |
| | ), |
| | ) |
| |
|
| | for field in ("metadata.hash", "metadata.file_id"): |
| | self.client.create_payload_index( |
| | collection_name=mt_collection_name, |
| | field_name=field, |
| | field_schema=models.KeywordIndexParams( |
| | type=models.KeywordIndexType.KEYWORD, |
| | on_disk=self.QDRANT_ON_DISK, |
| | ), |
| | ) |
| |
|
| | def _create_points( |
| | self, items: List[VectorItem], tenant_id: str |
| | ) -> List[PointStruct]: |
| | """ |
| | Create point structs from vector items with tenant ID. |
| | """ |
| | return [ |
| | PointStruct( |
| | id=item["id"], |
| | vector=item["vector"], |
| | payload={ |
| | "text": item["text"], |
| | "metadata": item["metadata"], |
| | TENANT_ID_FIELD: tenant_id, |
| | }, |
| | ) |
| | for item in items |
| | ] |
| |
|
| | def _ensure_collection( |
| | self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION |
| | ): |
| | """ |
| | Ensure the collection exists and payload indexes are created for tenant_id and metadata fields. |
| | """ |
| | if not self.client.collection_exists(collection_name=mt_collection_name): |
| | self._create_multi_tenant_collection(mt_collection_name, dimension) |
| |
|
| | def has_collection(self, collection_name: str) -> bool: |
| | """ |
| | Check if a logical collection exists by checking for any points with the tenant ID. |
| | """ |
| | if not self.client: |
| | return False |
| | mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) |
| | if not self.client.collection_exists(collection_name=mt_collection): |
| | return False |
| | tenant_filter = _tenant_filter(tenant_id) |
| | count_result = self.client.count( |
| | collection_name=mt_collection, |
| | count_filter=models.Filter(must=[tenant_filter]), |
| | ) |
| | return count_result.count > 0 |
| |
|
| | def delete( |
| | self, |
| | collection_name: str, |
| | ids: Optional[List[str]] = None, |
| | filter: Optional[Dict[str, Any]] = None, |
| | ): |
| | """ |
| | Delete vectors by ID or filter from a collection with tenant isolation. |
| | """ |
| | if not self.client: |
| | return None |
| |
|
| | mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) |
| | if not self.client.collection_exists(collection_name=mt_collection): |
| | log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete") |
| | return None |
| |
|
| | must_conditions = [_tenant_filter(tenant_id)] |
| | should_conditions = [] |
| | if ids: |
| | should_conditions = [_metadata_filter("id", id_value) for id_value in ids] |
| | elif filter: |
| | must_conditions += [_metadata_filter(k, v) for k, v in filter.items()] |
| |
|
| | return self.client.delete( |
| | collection_name=mt_collection, |
| | points_selector=models.FilterSelector( |
| | filter=models.Filter(must=must_conditions, should=should_conditions) |
| | ), |
| | ) |
| |
|
| | def search( |
| | self, |
| | collection_name: str, |
| | vectors: List[List[float | int]], |
| | filter: Optional[Dict] = None, |
| | limit: int = 10, |
| | ) -> Optional[SearchResult]: |
| | """ |
| | Search for the nearest neighbor items based on the vectors with tenant isolation. |
| | """ |
| | if not self.client or not vectors: |
| | return None |
| | mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) |
| | if not self.client.collection_exists(collection_name=mt_collection): |
| | log.debug(f"Collection {mt_collection} doesn't exist, search returns None") |
| | return None |
| |
|
| | tenant_filter = _tenant_filter(tenant_id) |
| | query_response = self.client.query_points( |
| | collection_name=mt_collection, |
| | query=vectors[0], |
| | limit=limit, |
| | query_filter=models.Filter(must=[tenant_filter]), |
| | ) |
| | get_result = self._result_to_get_result(query_response.points) |
| | return SearchResult( |
| | ids=get_result.ids, |
| | documents=get_result.documents, |
| | metadatas=get_result.metadatas, |
| | distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]], |
| | ) |
| |
|
| | def query( |
| | self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None |
| | ): |
| | """ |
| | Query points with filters and tenant isolation. |
| | """ |
| | if not self.client: |
| | return None |
| | mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) |
| | if not self.client.collection_exists(collection_name=mt_collection): |
| | log.debug(f"Collection {mt_collection} doesn't exist, query returns None") |
| | return None |
| | if limit is None: |
| | limit = NO_LIMIT |
| | tenant_filter = _tenant_filter(tenant_id) |
| | field_conditions = [_metadata_filter(k, v) for k, v in filter.items()] |
| | combined_filter = models.Filter(must=[tenant_filter, *field_conditions]) |
| | points = self.client.scroll( |
| | collection_name=mt_collection, |
| | scroll_filter=combined_filter, |
| | limit=limit, |
| | ) |
| | return self._result_to_get_result(points[0]) |
| |
|
| | def get(self, collection_name: str) -> Optional[GetResult]: |
| | """ |
| | Get all items in a collection with tenant isolation. |
| | """ |
| | if not self.client: |
| | return None |
| | mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) |
| | if not self.client.collection_exists(collection_name=mt_collection): |
| | log.debug(f"Collection {mt_collection} doesn't exist, get returns None") |
| | return None |
| | tenant_filter = _tenant_filter(tenant_id) |
| | points = self.client.scroll( |
| | collection_name=mt_collection, |
| | scroll_filter=models.Filter(must=[tenant_filter]), |
| | limit=NO_LIMIT, |
| | ) |
| | return self._result_to_get_result(points[0]) |
| |
|
| | def upsert(self, collection_name: str, items: List[VectorItem]): |
| | """ |
| | Upsert items with tenant ID. |
| | """ |
| | if not self.client or not items: |
| | return None |
| | mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) |
| | dimension = len(items[0]["vector"]) |
| | self._ensure_collection(mt_collection, dimension) |
| | points = self._create_points(items, tenant_id) |
| | self.client.upload_points(mt_collection, points) |
| | return None |
| |
|
| | def insert(self, collection_name: str, items: List[VectorItem]): |
| | """ |
| | Insert items with tenant ID. |
| | """ |
| | return self.upsert(collection_name, items) |
| |
|
| | def reset(self): |
| | """ |
| | Reset the database by deleting all collections. |
| | """ |
| | if not self.client: |
| | return None |
| | for collection in self.client.get_collections().collections: |
| | if collection.name.startswith(self.collection_prefix): |
| | self.client.delete_collection(collection_name=collection.name) |
| |
|
| | def delete_collection(self, collection_name: str): |
| | """ |
| | Delete a collection. |
| | """ |
| | if not self.client: |
| | return None |
| | mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) |
| | if not self.client.collection_exists(collection_name=mt_collection): |
| | log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete") |
| | return None |
| | self.client.delete( |
| | collection_name=mt_collection, |
| | points_selector=models.FilterSelector( |
| | filter=models.Filter(must=[_tenant_filter(tenant_id)]) |
| | ), |
| | ) |
| |
|