| | from pymilvus import MilvusClient as Client |
| | from pymilvus import FieldSchema, DataType |
| | from pymilvus import connections, Collection |
| |
|
| | import json |
| | import logging |
| | from typing import Optional |
| |
|
| | from open_webui.retrieval.vector.utils import process_metadata |
| | from open_webui.retrieval.vector.main import ( |
| | VectorDBBase, |
| | VectorItem, |
| | SearchResult, |
| | GetResult, |
| | ) |
| | from open_webui.config import ( |
| | MILVUS_URI, |
| | MILVUS_DB, |
| | MILVUS_TOKEN, |
| | MILVUS_INDEX_TYPE, |
| | MILVUS_METRIC_TYPE, |
| | MILVUS_HNSW_M, |
| | MILVUS_HNSW_EFCONSTRUCTION, |
| | MILVUS_IVF_FLAT_NLIST, |
| | MILVUS_DISKANN_MAX_DEGREE, |
| | MILVUS_DISKANN_SEARCH_LIST_SIZE, |
| | ) |
| |
|
# Module-level logger, named after this module, shared by everything below.
log = logging.getLogger(__name__)
| |
|
| |
|
class MilvusClient(VectorDBBase):
    """Milvus-backed implementation of ``VectorDBBase``.

    Every logical collection ``name`` is stored in Milvus as
    ``"<collection_prefix>_<name>"`` with ``-`` replaced by ``_``.  Rows have
    a VARCHAR primary key ``id``, a float ``vector`` and two JSON fields:
    ``data`` (``{"text": ...}``) and ``metadata``.
    """

    def __init__(self):
        """Create the underlying pymilvus client from the ``MILVUS_*`` config."""
        self.collection_prefix = "open_webui"
        connect_kwargs = {"uri": MILVUS_URI, "db_name": MILVUS_DB}
        # Only forward a token when one is configured.
        if MILVUS_TOKEN is not None:
            connect_kwargs["token"] = MILVUS_TOKEN
        self.client = Client(**connect_kwargs)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _full_name(self, collection_name: str) -> str:
        """Return the prefixed Milvus collection name for *collection_name*."""
        normalized = collection_name.replace("-", "_")
        return f"{self.collection_prefix}_{normalized}"

    @staticmethod
    def _build_filter_expression(conditions: Optional[dict]) -> str:
        """Translate an equality filter dict into a Milvus boolean expression.

        Each ``key: value`` pair becomes ``metadata["key"] == <value>`` and
        the pairs are joined with ``&&``.  Keys and values are JSON-encoded so
        that quotes and backslashes inside strings are escaped instead of
        breaking (or altering) the expression.  ``ensure_ascii=False`` keeps
        non-ASCII text verbatim rather than emitting ``\\uXXXX`` escapes.

        Returns an empty string for an empty or ``None`` filter.

        Raises:
            TypeError: if a value is not JSON-serializable.
        """
        if not conditions:
            return ""
        clauses = []
        for key, value in conditions.items():
            field = json.dumps(str(key), ensure_ascii=False)
            literal = json.dumps(value, ensure_ascii=False)
            clauses.append(f"metadata[{field}] == {literal}")
        return " && ".join(clauses)

    @staticmethod
    def _items_to_rows(items: list[VectorItem]) -> list[dict]:
        """Convert ``VectorItem``s into the row dicts stored in Milvus."""
        return [
            {
                "id": item["id"],
                "vector": item["vector"],
                "data": {"text": item["text"]},
                "metadata": process_metadata(item["metadata"]),
            }
            for item in items
        ]

    def _ensure_collection(
        self, collection_name: str, items: list[VectorItem], purpose: str = ""
    ) -> None:
        """Create the collection if it does not exist yet.

        The vector dimension is taken from the first item, so creation is
        impossible without items.

        Args:
            collection_name: Already-normalized (``-`` -> ``_``) logical name.
            items: Items about to be written.
            purpose: Text spliced into log/error messages (``""`` for insert,
                ``" for upsert"`` for upsert).

        Raises:
            ValueError: if the collection is missing and *items* is empty.
        """
        full_name = f"{self.collection_prefix}_{collection_name}"
        if self.client.has_collection(collection_name=full_name):
            return

        log.info(f"Collection {full_name} does not exist{purpose}. Creating now.")
        if not items:
            log.error(
                f"Cannot create collection {full_name}{purpose} without items to determine dimension."
            )
            raise ValueError(
                f"Cannot create Milvus collection{purpose} without items to determine vector dimension."
            )
        self._create_collection(
            collection_name=collection_name, dimension=len(items[0]["vector"])
        )

    def _result_to_get_result(self, result) -> GetResult:
        """Convert query output (a list of row lists) into a ``GetResult``.

        Each row is expected to expose ``id``, ``data`` (with a ``text`` key)
        and ``metadata`` via ``.get``.
        """
        ids = []
        documents = []
        metadatas = []
        for match in result:
            match_ids = []
            match_documents = []
            match_metadatas = []
            for item in match:
                match_ids.append(item.get("id"))
                match_documents.append(item.get("data", {}).get("text"))
                match_metadatas.append(item.get("metadata"))
            ids.append(match_ids)
            documents.append(match_documents)
            metadatas.append(match_metadatas)
        return GetResult(ids=ids, documents=documents, metadatas=metadatas)

    def _result_to_search_result(self, result) -> SearchResult:
        """Convert search output (one hit list per query vector) into a ``SearchResult``.

        Stored fields are read from each hit's ``entity`` mapping.
        """
        ids = []
        distances = []
        documents = []
        metadatas = []
        for match in result:
            match_ids = []
            match_distances = []
            match_documents = []
            match_metadatas = []
            for item in match:
                entity = item.get("entity", {})
                match_ids.append(item.get("id"))
                # Linear rescale (d + 1) / 2, i.e. [-1, 1] -> [0, 1].
                # NOTE(review): presumably assumes a cosine-style score;
                # confirm this is right for other MILVUS_METRIC_TYPE values.
                match_distances.append((item.get("distance") + 1.0) / 2.0)
                match_documents.append(entity.get("data", {}).get("text"))
                match_metadatas.append(entity.get("metadata"))
            ids.append(match_ids)
            distances.append(match_distances)
            documents.append(match_documents)
            metadatas.append(match_metadatas)
        return SearchResult(
            ids=ids,
            distances=distances,
            documents=documents,
            metadatas=metadatas,
        )

    def _create_collection(self, collection_name: str, dimension: int):
        """Create the prefixed collection with its schema and vector index.

        The index type and metric come from ``MILVUS_INDEX_TYPE`` and
        ``MILVUS_METRIC_TYPE`` (upper-cased); build-time parameters are taken
        from the matching ``MILVUS_*`` settings.

        Args:
            collection_name: Logical name (the prefix is added here).
            dimension: Dimension of the ``vector`` field.
        """
        full_name = self._full_name(collection_name)

        schema = self.client.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
        )
        schema.add_field(
            field_name="id",
            datatype=DataType.VARCHAR,
            is_primary=True,
            max_length=65535,
        )
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            dim=dimension,
            description="vector",
        )
        schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
        schema.add_field(
            field_name="metadata", datatype=DataType.JSON, description="metadata"
        )

        index_params = self.client.prepare_index_params()

        index_type = MILVUS_INDEX_TYPE.upper()
        metric_type = MILVUS_METRIC_TYPE.upper()

        log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")

        # Build-time parameters for the index types that take any.
        build_params_by_type = {
            "HNSW": {
                "M": MILVUS_HNSW_M,
                "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
            },
            "IVF_FLAT": {"nlist": MILVUS_IVF_FLAT_NLIST},
            "DISKANN": {
                "max_degree": MILVUS_DISKANN_MAX_DEGREE,
                "search_list_size": MILVUS_DISKANN_SEARCH_LIST_SIZE,
            },
        }

        index_creation_params = {}
        if index_type in build_params_by_type:
            index_creation_params = build_params_by_type[index_type]
            log.info(f"{index_type} params: {index_creation_params}")
        elif index_type in ("FLAT", "AUTOINDEX"):
            log.info(f"Using {index_type} index with no specific build-time params.")
        else:
            # The unknown type is still handed to Milvus below; this only warns.
            log.warning(
                f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
                f"Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. "
                f"Milvus will use its default for the collection if this type is not directly supported for index creation."
            )

        index_params.add_index(
            field_name="vector",
            index_type=index_type,
            metric_type=metric_type,
            params=index_creation_params,
        )

        self.client.create_collection(
            collection_name=full_name,
            schema=schema,
            index_params=index_params,
        )
        log.info(
            f"Successfully created collection '{full_name}' with index type '{index_type}' and metric '{metric_type}'."
        )

    # ------------------------------------------------------------------
    # VectorDBBase interface
    # ------------------------------------------------------------------

    def has_collection(self, collection_name: str) -> bool:
        """Return whether the prefixed collection exists in Milvus."""
        return self.client.has_collection(
            collection_name=self._full_name(collection_name)
        )

    def delete_collection(self, collection_name: str):
        """Drop the prefixed collection and return the client's result."""
        return self.client.drop_collection(
            collection_name=self._full_name(collection_name)
        )

    def search(
        self,
        collection_name: str,
        vectors: list[list[float | int]],
        filter: Optional[dict] = None,
        limit: int = 10,
    ) -> Optional[SearchResult]:
        """Return up to *limit* nearest neighbours for each query vector.

        Args:
            collection_name: Logical collection name.
            vectors: Query vectors.
            filter: Optional ``{metadata_key: value}`` equality filter; all
                pairs must match.  ``None`` or an empty dict means no filter.
            limit: Maximum number of hits per query vector.
        """
        search_kwargs = {
            "collection_name": self._full_name(collection_name),
            "data": vectors,
            "limit": limit,
            "output_fields": ["data", "metadata"],
        }
        # Previously the filter argument was accepted but never applied.
        filter_string = self._build_filter_expression(filter)
        if filter_string:
            search_kwargs["filter"] = filter_string

        result = self.client.search(**search_kwargs)
        return self._result_to_search_result(result)

    def query(self, collection_name: str, filter: dict, limit: int = -1):
        """Fetch rows whose metadata equals every ``key: value`` in *filter*.

        Args:
            collection_name: Logical collection name.
            filter: ``{metadata_key: value}`` equality filter; an empty dict
                matches every row.
            limit: Maximum number of rows; any non-positive value (or
                ``None``) is passed to Milvus as ``-1``.

        Returns:
            A ``GetResult`` holding a single result list, or ``None`` when the
            collection does not exist or the query fails.

        Raises:
            TypeError: if a filter value is not JSON-serializable.
        """
        # The ORM ``Collection`` API used below needs a registered connection.
        connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB)

        collection_name = collection_name.replace("-", "_")
        full_name = f"{self.collection_prefix}_{collection_name}"
        if not self.has_collection(collection_name):
            log.warning(f"Query attempted on non-existent collection: {full_name}")
            return None

        filter_string = self._build_filter_expression(filter)
        effective_limit = limit if limit is not None and limit > 0 else -1

        collection = Collection(full_name)
        collection.load()

        try:
            log.info(
                f"Querying collection {full_name} with filter: '{filter_string}', limit: {limit}"
            )

            iterator = collection.query_iterator(
                expr=filter_string,
                output_fields=[
                    "id",
                    "data",
                    "metadata",
                ],
                limit=effective_limit,
            )

            all_results = []
            try:
                while True:
                    batch = iterator.next()
                    if not batch:
                        break
                    all_results.extend(batch)
            finally:
                # Close the iterator even when fetching a batch fails.
                iterator.close()

            log.debug(f"Total results from query: {len(all_results)}")
            return self._result_to_get_result([all_results])

        except Exception as e:
            log.exception(
                f"Error querying collection {full_name} with filter '{filter_string}' and limit {limit}: {e}"
            )
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return every row of the collection (an unfiltered, unlimited query)."""
        collection_name = collection_name.replace("-", "_")
        log.warning(
            f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
        )
        return self.query(collection_name=collection_name, filter={}, limit=-1)

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Insert *items*, creating the collection first if it is missing.

        Raises:
            ValueError: if the collection must be created but *items* is empty.
        """
        collection_name = collection_name.replace("-", "_")
        self._ensure_collection(collection_name, items)

        full_name = f"{self.collection_prefix}_{collection_name}"
        log.info(f"Inserting {len(items)} items into collection {full_name}.")
        return self.client.insert(
            collection_name=full_name,
            data=self._items_to_rows(items),
        )

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Upsert *items*, creating the collection first if it is missing.

        Raises:
            ValueError: if the collection must be created but *items* is empty.
        """
        collection_name = collection_name.replace("-", "_")
        self._ensure_collection(collection_name, items, purpose=" for upsert")

        full_name = f"{self.collection_prefix}_{collection_name}"
        log.info(f"Upserting {len(items)} items into collection {full_name}.")
        return self.client.upsert(
            collection_name=full_name,
            data=self._items_to_rows(items),
        )

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """Delete rows by primary key or by metadata equality filter.

        *ids* takes precedence over *filter*.  Returns the client's result, or
        ``None`` when the collection does not exist or neither argument is
        given (nothing is deleted in that case).
        """
        full_name = self._full_name(collection_name)
        if not self.has_collection(collection_name):
            log.warning(f"Delete attempted on non-existent collection: {full_name}")
            return None

        if ids:
            log.info(f"Deleting items by IDs from {full_name}. IDs: {ids}")
            return self.client.delete(
                collection_name=full_name,
                ids=ids,
            )

        if filter:
            filter_string = self._build_filter_expression(filter)
            log.info(
                f"Deleting items by filter from {full_name}. Filter: {filter_string}"
            )
            return self.client.delete(
                collection_name=full_name,
                filter=filter_string,
            )

        log.warning(
            f"Delete operation on {full_name} called without IDs or filter. No action taken."
        )
        return None

    def reset(self):
        """Drop every collection owned by this client.

        Only names starting with ``"<collection_prefix>_"`` are dropped, so
        collections that merely begin with the same letters (for example
        ``open_webui2_x``) are left alone.  Failures are logged and the
        remaining collections are still processed.
        """
        log.warning(
            f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
        )
        owned_prefix = f"{self.collection_prefix}_"
        deleted_collections = []
        for collection_name_full in self.client.list_collections():
            if not collection_name_full.startswith(owned_prefix):
                continue
            try:
                self.client.drop_collection(collection_name=collection_name_full)
                deleted_collections.append(collection_name_full)
                log.info(f"Deleted collection: {collection_name_full}")
            except Exception as e:
                log.error(f"Error deleting collection {collection_name_full}: {e}")
        log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")
| |
|