| from pymilvus import MilvusClient as Client |
| from pymilvus import FieldSchema, DataType |
| import json |
| import logging |
| from typing import Optional |
|
|
| from open_webui.retrieval.vector.main import ( |
| VectorDBBase, |
| VectorItem, |
| SearchResult, |
| GetResult, |
| ) |
| from open_webui.config import ( |
| MILVUS_URI, |
| MILVUS_DB, |
| MILVUS_TOKEN, |
| ) |
| from open_webui.env import SRC_LOG_LEVELS |
|
|
# Module-level logger, set to the project's configured RAG log level.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|
|
|
class MilvusClient(VectorDBBase):
    """Vector store backed by a Milvus server via ``pymilvus.MilvusClient``.

    Every collection is namespaced under ``self.collection_prefix``, and any
    ``-`` in a caller-supplied collection name is replaced with ``_`` (Milvus
    collection names may not contain dashes).
    """

    def __init__(self):
        self.collection_prefix = "open_webui"
        # Only pass a token when one is configured; deployments without
        # authentication leave MILVUS_TOKEN unset.
        if MILVUS_TOKEN is None:
            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
        else:
            self.client = Client(
                uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN
            )

    def _full_collection_name(self, collection_name: str) -> str:
        """Return the prefixed, Milvus-safe physical name for a logical
        collection name (dashes normalized to underscores; idempotent)."""
        return f"{self.collection_prefix}_{collection_name.replace('-', '_')}"

    def _result_to_get_result(self, result) -> GetResult:
        """Convert raw Milvus query rows into a ``GetResult``.

        ``result`` is a list of matches; each match is an iterable of row
        dicts carrying ``id``, ``data`` (``{"text": ...}``) and ``metadata``.
        """
        ids = []
        documents = []
        metadatas = []

        for match in result:
            _ids = []
            _documents = []
            _metadatas = []
            for item in match:
                _ids.append(item.get("id"))
                # "data" may be present but None; `or {}` guards both the
                # missing-key and explicit-None cases before .get("text").
                _documents.append((item.get("data") or {}).get("text"))
                _metadatas.append(item.get("metadata"))

            ids.append(_ids)
            documents.append(_documents)
            metadatas.append(_metadatas)

        return GetResult(
            ids=ids,
            documents=documents,
            metadatas=metadatas,
        )

    def _result_to_search_result(self, result) -> SearchResult:
        """Convert raw Milvus search hits into a ``SearchResult``.

        COSINE similarity from Milvus lies in [-1, 1]; it is rescaled to
        [0, 1] so that larger values still mean "more similar".
        """
        ids = []
        distances = []
        documents = []
        metadatas = []

        for match in result:
            _ids = []
            _distances = []
            _documents = []
            _metadatas = []

            for item in match:
                _ids.append(item.get("id"))
                # Normalize the COSINE score from [-1, 1] into [0, 1].
                _distances.append((item.get("distance") + 1.0) / 2.0)
                # "entity"/"data" may be present but None; guard with `or {}`.
                entity = item.get("entity") or {}
                _documents.append((entity.get("data") or {}).get("text"))
                _metadatas.append(entity.get("metadata"))

            ids.append(_ids)
            distances.append(_distances)
            documents.append(_documents)
            metadatas.append(_metadatas)

        return SearchResult(
            ids=ids,
            distances=distances,
            documents=documents,
            metadatas=metadatas,
        )

    def _create_collection(self, collection_name: str, dimension: int):
        """Create the prefixed collection: VARCHAR primary key, float vector
        of ``dimension``, JSON ``data``/``metadata`` fields, and an HNSW
        index over the vector using the COSINE metric."""
        schema = self.client.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
        )
        schema.add_field(
            field_name="id",
            datatype=DataType.VARCHAR,
            is_primary=True,
            max_length=65535,
        )
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            dim=dimension,
            description="vector",
        )
        schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
        schema.add_field(
            field_name="metadata", datatype=DataType.JSON, description="metadata"
        )

        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_type="HNSW",
            metric_type="COSINE",
            params={"M": 16, "efConstruction": 100},
        )

        self.client.create_collection(
            collection_name=self._full_collection_name(collection_name),
            schema=schema,
            index_params=index_params,
        )

    def has_collection(self, collection_name: str) -> bool:
        """Return True if the (prefixed) collection exists in Milvus."""
        return self.client.has_collection(
            collection_name=self._full_collection_name(collection_name)
        )

    def delete_collection(self, collection_name: str):
        """Drop the (prefixed) collection."""
        return self.client.drop_collection(
            collection_name=self._full_collection_name(collection_name)
        )

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        """Run a vector similarity search and return normalized results."""
        result = self.client.search(
            collection_name=self._full_collection_name(collection_name),
            data=vectors,
            limit=limit,
            output_fields=["data", "metadata"],
        )

        return self._result_to_search_result(result)

    def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
        """Return items whose metadata matches every key/value in ``filter``.

        Milvus caps a single query page, so results are fetched with
        limit/offset pagination and concatenated. ``limit=None`` fetches
        everything. Returns a ``GetResult``, or None when the collection does
        not exist or the query fails.
        """
        if not self.has_collection(collection_name):
            return None

        # Boolean expression of the form: metadata["k"] == "v" && ...
        # json.dumps quotes/escapes the value per Milvus expression syntax.
        filter_string = " && ".join(
            f'metadata["{key}"] == {json.dumps(value)}'
            for key, value in filter.items()
        )

        # Page size below the Milvus 16384 limit+offset window cap.
        max_limit = 16383
        all_results = []

        remaining = float("inf") if limit is None else limit
        offset = 0

        try:
            while remaining > 0:
                # Lazy %-formatting: no string built unless INFO is enabled.
                log.info("remaining: %s", remaining)
                current_fetch = min(max_limit, remaining)

                results = self.client.query(
                    collection_name=self._full_collection_name(collection_name),
                    filter=filter_string,
                    output_fields=["*"],
                    limit=current_fetch,
                    offset=offset,
                )

                if not results:
                    break

                all_results.extend(results)
                results_count = len(results)
                remaining -= results_count
                offset += results_count

                # A short page means the collection is exhausted.
                if results_count < current_fetch:
                    break

            log.debug(all_results)
            return self._result_to_get_result([all_results])
        except Exception as e:
            log.exception(
                "Error querying collection %s with limit %s: %s",
                collection_name,
                limit,
                e,
            )
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return every item in the collection (non-empty primary key)."""
        result = self.client.query(
            collection_name=self._full_collection_name(collection_name),
            filter='id != ""',
        )
        return self._result_to_get_result([result])

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Insert ``items``, lazily creating the collection sized to the
        first item's vector. Returns None for an empty list (the original
        raised IndexError there when the collection was missing)."""
        if not items:
            # Nothing to insert — and no vector to infer a dimension from.
            return None

        if not self.client.has_collection(
            collection_name=self._full_collection_name(collection_name)
        ):
            self._create_collection(
                collection_name=collection_name, dimension=len(items[0]["vector"])
            )

        return self.client.insert(
            collection_name=self._full_collection_name(collection_name),
            data=[
                {
                    "id": item["id"],
                    "vector": item["vector"],
                    "data": {"text": item["text"]},
                    "metadata": item["metadata"],
                }
                for item in items
            ],
        )

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Insert-or-replace ``items`` by primary key, lazily creating the
        collection sized to the first item's vector. No-op on an empty list."""
        if not items:
            # Nothing to upsert — and no vector to infer a dimension from.
            return None

        if not self.client.has_collection(
            collection_name=self._full_collection_name(collection_name)
        ):
            self._create_collection(
                collection_name=collection_name, dimension=len(items[0]["vector"])
            )

        return self.client.upsert(
            collection_name=self._full_collection_name(collection_name),
            data=[
                {
                    "id": item["id"],
                    "vector": item["vector"],
                    "data": {"text": item["text"]},
                    "metadata": item["metadata"],
                }
                for item in items
            ],
        )

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """Delete by explicit ``ids``, or else by a metadata ``filter``.

        Returns the Milvus delete result, or None when neither selector is
        provided (matching the original implicit fall-through).
        """
        target = self._full_collection_name(collection_name)
        if ids:
            return self.client.delete(
                collection_name=target,
                ids=ids,
            )
        if filter:
            # Same expression syntax as query(): ANDed equality terms.
            filter_string = " && ".join(
                f'metadata["{key}"] == {json.dumps(value)}'
                for key, value in filter.items()
            )
            return self.client.delete(
                collection_name=target,
                filter=filter_string,
            )
        return None

    def reset(self):
        """Drop every collection owned by this client (prefix match only;
        collections outside the prefix are left untouched)."""
        for collection_name in self.client.list_collections():
            if collection_name.startswith(self.collection_prefix):
                self.client.drop_collection(collection_name=collection_name)
|
|