| | import logging |
| | from typing import Optional, Tuple, List, Dict, Any |
| |
|
| | from open_webui.config import ( |
| | MILVUS_URI, |
| | MILVUS_TOKEN, |
| | MILVUS_DB, |
| | MILVUS_COLLECTION_PREFIX, |
| | MILVUS_INDEX_TYPE, |
| | MILVUS_METRIC_TYPE, |
| | MILVUS_HNSW_M, |
| | MILVUS_HNSW_EFCONSTRUCTION, |
| | MILVUS_IVF_FLAT_NLIST, |
| | ) |
| | from open_webui.retrieval.vector.main import ( |
| | GetResult, |
| | SearchResult, |
| | VectorDBBase, |
| | VectorItem, |
| | ) |
| | from pymilvus import ( |
| | connections, |
| | utility, |
| | Collection, |
| | CollectionSchema, |
| | FieldSchema, |
| | DataType, |
| | ) |
| |
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Name of the scalar field that records which logical (per-tenant) collection
# a row stored in a shared Milvus collection belongs to.
RESOURCE_ID_FIELD = "resource_id"
| |
|
| |
|
class MilvusClient(VectorDBBase):
    """Multi-tenant Milvus vector store.

    Instead of creating one Milvus collection per logical collection, every
    logical collection name is routed to one of five shared Milvus
    collections (memories, knowledge, files, web search, hash based). Each
    stored row carries the original collection name in ``RESOURCE_ID_FIELD``
    and every read/delete is restricted to that value.
    """

    def __init__(self):
        """Connect to Milvus and derive the names of the shared collections."""
        # Hyphens in the configured prefix are replaced with underscores
        # (presumably because Milvus collection names cannot contain "-").
        self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_")
        connections.connect(
            alias="default",
            uri=MILVUS_URI,
            token=MILVUS_TOKEN,
            db_name=MILVUS_DB,
        )

        # One shared Milvus collection per category of data.
        self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
        self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
        self.FILE_COLLECTION = f"{self.collection_prefix}_files"
        self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search"
        self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based"
        self.shared_collections = [
            self.MEMORY_COLLECTION,
            self.KNOWLEDGE_COLLECTION,
            self.FILE_COLLECTION,
            self.WEB_SEARCH_COLLECTION,
            self.HASH_BASED_COLLECTION,
        ]

    # ------------------------------------------------------------------
    # Expression helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _quote(value: str) -> str:
        """Return ``value`` as a single-quoted, escaped filter string literal.

        Backslashes and single quotes are escaped so that a value containing
        them cannot terminate the literal early and alter the expression.
        """
        escaped = value.replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    @classmethod
    def _literal(cls, value: Any) -> str:
        """Render a filter value: strings are quoted, anything else is used as-is."""
        if isinstance(value, str):
            return cls._quote(value)
        return f"{value}"

    @classmethod
    def _build_expr(
        cls,
        resource_id: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Build the boolean expression shared by query, search and delete.

        The expression always restricts to ``resource_id``; ``ids`` and
        ``filter`` (equality on ``metadata`` keys) add further ``and`` clauses
        when they are non-empty.
        """
        clauses = [f"{RESOURCE_ID_FIELD} == {cls._quote(resource_id)}"]
        if ids:
            id_list = ", ".join(cls._quote(str(id_val)) for id_val in ids)
            clauses.append(f"id in [{id_list}]")
        if filter:
            clauses.extend(
                f"metadata[{cls._quote(str(key))}] == {cls._literal(value)}"
                for key, value in filter.items()
            )
        return " and ".join(clauses)

    # ------------------------------------------------------------------
    # Routing and collection management
    # ------------------------------------------------------------------
    def _get_collection_and_resource_id(self, collection_name: str) -> Tuple[str, str]:
        """
        Maps the traditional collection name to multi-tenant collection and resource ID.

        The resource ID is always the unmodified ``collection_name``.

        WARNING: This mapping relies on current Open WebUI naming conventions for
        collection names. If Open WebUI changes how it generates collection names
        (e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash
        formats), this mapping will break and route data to incorrect collections,
        potentially causing data corruption, consistency issues and incorrect
        data mapping inside the database.
        """
        resource_id = collection_name

        prefix_routes = (
            ("user-memory-", self.MEMORY_COLLECTION),
            ("file-", self.FILE_COLLECTION),
            ("web-search-", self.WEB_SEARCH_COLLECTION),
        )
        for prefix, target in prefix_routes:
            if collection_name.startswith(prefix):
                return target, resource_id

        # Names made of exactly 63 lowercase hex characters are treated as hashes.
        if len(collection_name) == 63 and all(
            c in "0123456789abcdef" for c in collection_name
        ):
            return self.HASH_BASED_COLLECTION, resource_id

        # Everything else is considered a knowledge collection.
        return self.KNOWLEDGE_COLLECTION, resource_id

    def _create_shared_collection(self, mt_collection_name: str, dimension: int):
        """Create a shared collection with its schema and indexes and return it.

        Args:
            mt_collection_name: Name of the shared Milvus collection to create.
            dimension: Dimensionality of the ``vector`` field.
        """
        fields = [
            FieldSchema(
                name="id",
                dtype=DataType.VARCHAR,
                is_primary=True,
                auto_id=False,
                max_length=36,
            ),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="metadata", dtype=DataType.JSON),
            FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255),
        ]
        schema = CollectionSchema(fields, "Shared collection for multi-tenancy")
        collection = Collection(mt_collection_name, schema)

        index_params = {
            "metric_type": MILVUS_METRIC_TYPE,
            "index_type": MILVUS_INDEX_TYPE,
            "params": {},
        }
        # Only HNSW and IVF_FLAT get build parameters; other index types use {}.
        if MILVUS_INDEX_TYPE == "HNSW":
            index_params["params"] = {
                "M": MILVUS_HNSW_M,
                "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
            }
        elif MILVUS_INDEX_TYPE == "IVF_FLAT":
            index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST}

        collection.create_index("vector", index_params)
        # Index the tenant field as well, since every expression filters on it.
        collection.create_index(RESOURCE_ID_FIELD)
        log.info("Created shared collection: %s", mt_collection_name)
        return collection

    def _ensure_collection(self, mt_collection_name: str, dimension: int):
        """Create the shared collection if it does not exist yet."""
        if not utility.has_collection(mt_collection_name):
            self._create_shared_collection(mt_collection_name, dimension)

    def has_collection(self, collection_name: str) -> bool:
        """Return True if at least one row is stored for ``collection_name``."""
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return False

        collection = Collection(mt_collection)
        collection.load()
        res = collection.query(expr=self._build_expr(resource_id), limit=1)
        return len(res) > 0

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------
    def _write(self, collection_name: str, items: List[VectorItem], *, overwrite: bool):
        """Store ``items`` under ``collection_name``.

        With ``overwrite=True`` rows whose primary key already exists are
        replaced (``Collection.upsert``); otherwise rows are appended
        (``Collection.insert``). Does nothing when ``items`` is empty.
        """
        if not items:
            return
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        # The vector dimension of the shared collection is taken from the first item.
        dimension = len(items[0]["vector"])
        self._ensure_collection(mt_collection, dimension)
        collection = Collection(mt_collection)

        entities = [
            {
                "id": item["id"],
                "vector": item["vector"],
                "text": item["text"],
                "metadata": item["metadata"],
                RESOURCE_ID_FIELD: resource_id,
            }
            for item in items
        ]
        if overwrite:
            collection.upsert(entities)
        else:
            collection.insert(entities)

    def upsert(self, collection_name: str, items: List[VectorItem]):
        """Insert ``items``, replacing any existing rows that share an ``id``.

        Items are read by key (``id``, ``vector``, ``text``, ``metadata``).
        NOTE: ``id`` is the primary key of the *shared* collection, so it is
        unique across all resources stored in that collection, not per resource.
        """
        return self._write(collection_name, items, overwrite=True)

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------
    def search(
        self,
        collection_name: str,
        vectors: List[List[float]],
        filter: Optional[Dict] = None,
        limit: int = 10,
    ) -> Optional[SearchResult]:
        """Run a vector similarity search restricted to ``collection_name``.

        Args:
            collection_name: Logical collection to search in.
            vectors: Query vectors; one result batch is produced per vector.
            filter: Optional equality constraints on ``metadata`` keys.
            limit: Maximum number of hits per query vector.

        Returns:
            A ``SearchResult`` with one list per query vector, or ``None`` when
            ``vectors`` is empty or the shared collection does not exist.
        """
        if not vectors:
            return None

        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return None

        collection = Collection(mt_collection)
        collection.load()

        search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}}
        results = collection.search(
            data=vectors,
            anns_field="vector",
            param=search_params,
            limit=limit,
            expr=self._build_expr(resource_id, filter=filter),
            output_fields=["id", "text", "metadata"],
        )

        ids, documents, metadatas, distances = [], [], [], []
        for hits in results:
            ids.append([hit.entity.get("id") for hit in hits])
            documents.append([hit.entity.get("text") for hit in hits])
            metadatas.append([hit.entity.get("metadata") for hit in hits])
            distances.append([hit.distance for hit in hits])

        return SearchResult(
            ids=ids, documents=documents, metadatas=metadatas, distances=distances
        )

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ):
        """Delete rows of ``collection_name`` matching ``ids`` and/or ``filter``.

        When neither ``ids`` nor ``filter`` is given (or both are empty), every
        row belonging to ``collection_name`` is deleted. Does nothing if the
        shared collection does not exist.
        """
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return

        collection = Collection(mt_collection)
        collection.delete(self._build_expr(resource_id, ids=ids, filter=filter))

    def reset(self):
        """Drop every shared collection that currently exists."""
        for collection_name in self.shared_collections:
            if utility.has_collection(collection_name):
                utility.drop_collection(collection_name)

    def delete_collection(self, collection_name: str):
        """Delete all rows of ``collection_name``; the shared collection is kept."""
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return

        collection = Collection(mt_collection)
        collection.delete(self._build_expr(resource_id))

    def query(
        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch rows of ``collection_name`` whose metadata matches ``filter``.

        Args:
            collection_name: Logical collection to read from.
            filter: Equality constraints on ``metadata`` keys; may be empty.
            limit: Maximum number of rows; ``None`` or ``0`` means no limit.

        Returns:
            A ``GetResult`` holding a single batch of ids, documents and
            metadatas, or ``None`` if the shared collection does not exist.
        """
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return None

        collection = Collection(mt_collection)
        collection.load()

        iterator = collection.query_iterator(
            expr=self._build_expr(resource_id, filter=filter),
            output_fields=["id", "text", "metadata"],
            limit=limit if limit else -1,
        )

        all_results = []
        try:
            # An empty batch signals that the iterator is exhausted.
            while True:
                batch = iterator.next()
                if not batch:
                    break
                all_results.extend(batch)
        finally:
            # Release the iterator even if fetching a batch raises.
            iterator.close()

        ids = [res["id"] for res in all_results]
        documents = [res["text"] for res in all_results]
        metadatas = [res["metadata"] for res in all_results]

        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return every row stored for ``collection_name``."""
        return self.query(collection_name, filter={}, limit=None)

    def insert(self, collection_name: str, items: List[VectorItem]):
        """Append ``items`` to ``collection_name`` without replacing existing rows."""
        return self._write(collection_name, items, overwrite=False)
| |
|