import logging
from typing import Optional

from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models

from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import QDRANT_URI, QDRANT_API_KEY
| |
|
# Upper bound substituted for `limit` whenever a caller asks for "no limit",
# since the Qdrant query calls in this module are always given an integer.
NO_LIMIT = 999_999_999
| |
|
| |
|
log = logging.getLogger(__name__)


class QdrantClient:
    """Adapter exposing Open WebUI's vector-store operations on top of Qdrant.

    A logical collection ``name`` is stored in Qdrant as
    ``"{collection_prefix}_{name}"``. Each point's payload carries the chunk
    text under ``"text"`` and the caller's metadata dict under ``"metadata"``.
    """

    def __init__(self):
        self.collection_prefix = "open-webui"
        self.QDRANT_URI = QDRANT_URI
        self.QDRANT_API_KEY = QDRANT_API_KEY
        # No client is built when QDRANT_URI is empty; every other method
        # uses self.client directly and will fail in that configuration.
        self.client = (
            Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
            if self.QDRANT_URI
            else None
        )

    def _collection_name(self, collection_name: str) -> str:
        """Return the prefixed Qdrant name for the logical *collection_name*."""
        return f"{self.collection_prefix}_{collection_name}"

    def _result_to_get_result(self, points) -> GetResult:
        """Convert Qdrant points into a single-batch ``GetResult``.

        Each point's payload must contain ``"text"`` and ``"metadata"`` (the
        layout written by ``_create_points``). Every field is wrapped in an
        outer list, i.e. ``ids == [[id, ...]]``.
        """
        ids = []
        documents = []
        metadatas = []
        for point in points:
            payload = point.payload
            ids.append(point.id)
            documents.append(payload["text"])
            metadatas.append(payload["metadata"])

        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    def _create_collection(self, collection_name: str, dimension: int):
        """Create a cosine-distance collection holding vectors of *dimension*."""
        collection_name_with_prefix = self._collection_name(collection_name)
        self.client.create_collection(
            collection_name=collection_name_with_prefix,
            vectors_config=models.VectorParams(
                size=dimension, distance=models.Distance.COSINE
            ),
        )
        log.info("collection %s successfully created!", collection_name_with_prefix)

    def _create_collection_if_not_exists(self, collection_name: str, dimension: int):
        """Create the collection with the given vector size unless it exists."""
        if not self.has_collection(collection_name=collection_name):
            self._create_collection(
                collection_name=collection_name, dimension=dimension
            )

    def _create_points(self, items: list[VectorItem]):
        """Build ``PointStruct``s storing each item's text and metadata as payload."""
        return [
            PointStruct(
                id=item["id"],
                vector=item["vector"],
                payload={"text": item["text"], "metadata": item["metadata"]},
            )
            for item in items
        ]

    def has_collection(self, collection_name: str) -> bool:
        """Return whether the prefixed collection exists in Qdrant."""
        return self.client.collection_exists(self._collection_name(collection_name))

    def delete_collection(self, collection_name: str):
        """Drop the prefixed collection and return the client's response."""
        return self.client.delete_collection(
            collection_name=self._collection_name(collection_name)
        )

    def search(
        self,
        collection_name: str,
        vectors: list[list[float | int]],
        limit: Optional[int],
    ) -> Optional[SearchResult]:
        """Return the points nearest to the first query vector.

        Only ``vectors[0]`` is used. ``limit=None`` is replaced by
        ``NO_LIMIT``. Returns ``None`` when *vectors* is empty. The
        ``distances`` field holds Qdrant's ``score`` for each hit; for the
        cosine collections created here that is a similarity (higher is
        closer) — NOTE(review): confirm callers expect that orientation.
        """
        if not vectors:
            return None
        if limit is None:
            limit = NO_LIMIT

        query_response = self.client.query_points(
            collection_name=self._collection_name(collection_name),
            query=vectors[0],
            limit=limit,
        )
        get_result = self._result_to_get_result(query_response.points)
        return SearchResult(
            ids=get_result.ids,
            documents=get_result.documents,
            metadatas=get_result.metadatas,
            distances=[[point.score for point in query_response.points]],
        )

    def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
        """Return points whose metadata matches every ``key: value`` in *filter*.

        Returns ``None`` when the collection does not exist or when anything
        in the query fails (the error is logged; this is best-effort).
        ``limit=None`` is replaced by ``NO_LIMIT``.
        """
        if not self.has_collection(collection_name):
            return None
        try:
            if limit is None:
                limit = NO_LIMIT

            field_conditions = [
                models.FieldCondition(
                    key=f"metadata.{key}", match=models.MatchValue(value=value)
                )
                for key, value in filter.items()
            ]
            # `must` = AND, the same way `delete` interprets a filter dict;
            # `should` would return points matching ANY single pair.
            points = self.client.query_points(
                collection_name=self._collection_name(collection_name),
                query_filter=models.Filter(must=field_conditions),
                limit=limit,
            )
            return self._result_to_get_result(points.points)
        except Exception:
            log.exception("Error querying collection '%s'", collection_name)
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return up to ``NO_LIMIT`` points of the collection, unfiltered."""
        points = self.client.query_points(
            collection_name=self._collection_name(collection_name),
            limit=NO_LIMIT,
        )
        return self._result_to_get_result(points.points)

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Upload *items*, creating the collection first if it does not exist.

        The collection's vector size is taken from the first item. An empty
        *items* list is a no-op.
        """
        if not items:
            return
        self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
        points = self._create_points(items)
        self.client.upload_points(self._collection_name(collection_name), points)

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Upsert *items*, creating the collection first if it does not exist.

        Returns the client's upsert response, or ``None`` for an empty list.
        """
        if not items:
            return None
        self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
        points = self._create_points(items)
        return self.client.upsert(self._collection_name(collection_name), points)

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """Delete points selected by *ids* or, if no ids are given, by *filter*.

        ``ids`` are compared with each point's ``metadata.id`` payload field;
        a point is deleted when it equals ANY of them. Otherwise every
        ``key: value`` pair in *filter* must match ``metadata.<key>``.
        With neither argument nothing is deleted and ``None`` is returned.
        """
        if ids:
            # One MatchAny condition: AND-ing one equality per id (as before)
            # can never match more than a single id value.
            field_conditions = [
                models.FieldCondition(
                    key="metadata.id",
                    match=models.MatchAny(any=list(ids)),
                )
            ]
        elif filter:
            field_conditions = [
                models.FieldCondition(
                    key=f"metadata.{key}",
                    match=models.MatchValue(value=value),
                )
                for key, value in filter.items()
            ]
        else:
            # An empty `must` filter matches every point, which would wipe
            # the whole collection; treat "no selector" as "delete nothing".
            return None

        return self.client.delete(
            collection_name=self._collection_name(collection_name),
            points_selector=models.FilterSelector(
                filter=models.Filter(must=field_conditions)
            ),
        )

    def reset(self):
        """Delete every Qdrant collection created through this adapter.

        Matches on ``"{collection_prefix}_"`` including the separator, the
        exact form used by ``_collection_name``, so collections that merely
        start with the same text (e.g. ``"open-webui-other"``) are kept.
        """
        prefix = f"{self.collection_prefix}_"
        for collection in self.client.get_collections().collections:
            if collection.name.startswith(prefix):
                self.client.delete_collection(collection_name=collection.name)
| |
|