import logging
from typing import Optional

import chromadb
from chromadb import Settings
from chromadb.utils.batch_utils import create_batches

from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
    CHROMA_DATA_PATH,
    CHROMA_HTTP_HOST,
    CHROMA_HTTP_PORT,
    CHROMA_HTTP_HEADERS,
    CHROMA_HTTP_SSL,
    CHROMA_TENANT,
    CHROMA_DATABASE,
    CHROMA_CLIENT_AUTH_PROVIDER,
    CHROMA_CLIENT_AUTH_CREDENTIALS,
)
|
|
|
|
class ChromaClient:
    """Vector-store adapter backed by ChromaDB.

    Uses a remote ``chromadb.HttpClient`` when ``CHROMA_HTTP_HOST`` is
    non-empty, otherwise an on-disk ``chromadb.PersistentClient`` rooted at
    ``CHROMA_DATA_PATH``.
    """

    # Logger for this module; held on the class so the block is self-contained.
    _log = logging.getLogger(__name__)

    def __init__(self):
        """Build Chroma client settings from config and open the client."""
        settings_dict = {
            "allow_reset": True,
            "anonymized_telemetry": False,
        }
        # Auth settings are forwarded only when they are configured (non-None).
        if CHROMA_CLIENT_AUTH_PROVIDER is not None:
            settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
        if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
            settings_dict["chroma_client_auth_credentials"] = (
                CHROMA_CLIENT_AUTH_CREDENTIALS
            )
        settings = Settings(**settings_dict)

        if CHROMA_HTTP_HOST != "":
            self.client = chromadb.HttpClient(
                host=CHROMA_HTTP_HOST,
                port=CHROMA_HTTP_PORT,
                headers=CHROMA_HTTP_HEADERS,
                ssl=CHROMA_HTTP_SSL,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
                settings=settings,
            )
        else:
            self.client = chromadb.PersistentClient(
                path=CHROMA_DATA_PATH,
                settings=settings,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
            )

    def has_collection(self, collection_name: str) -> bool:
        """Return True if a collection named ``collection_name`` exists.

        Depending on the chromadb release, ``list_collections()`` yields either
        plain collection names or ``Collection`` objects; both are accepted.
        """
        existing_names = {
            entry if isinstance(entry, str) else entry.name
            for entry in self.client.list_collections()
        }
        return collection_name in existing_names

    def delete_collection(self, collection_name: str):
        """Delete the named collection and return the client's result."""
        return self.client.delete_collection(name=collection_name)

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        """Query ``collection_name`` for the ``limit`` nearest neighbours of ``vectors``.

        Returns:
            A SearchResult whose ids/distances/documents/metadatas are copied
            from ``collection.query``. Returns None if the collection is falsy
            or if any step raises; failures are logged rather than propagated.
        """
        try:
            collection = self.client.get_collection(name=collection_name)
            if not collection:
                return None
            result = collection.query(
                query_embeddings=vectors,
                n_results=limit,
            )
            return SearchResult(
                ids=result["ids"],
                distances=result["distances"],
                documents=result["documents"],
                metadatas=result["metadatas"],
            )
        except Exception as e:
            # Best-effort contract: callers get None, but the cause is recorded.
            self._log.warning(
                "Chroma search on collection %r failed: %s", collection_name, e
            )
            return None

    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch records matching the metadata ``filter`` (optionally capped at ``limit``).

        Returns:
            A GetResult with each field wrapped in an outer list, or None if the
            collection is falsy or any step raises (the failure is logged).
        """
        try:
            collection = self.client.get_collection(name=collection_name)
            if not collection:
                return None
            result = collection.get(
                where=filter,
                limit=limit,
            )
            # Outer list presumably aligns the shape with SearchResult's nested
            # per-query lists -- confirm against GetResult consumers.
            return GetResult(
                ids=[result["ids"]],
                documents=[result["documents"]],
                metadatas=[result["metadatas"]],
            )
        except Exception as e:
            self._log.warning(
                "Chroma query on collection %r with filter %r failed: %s",
                collection_name,
                filter,
                e,
            )
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return every record in ``collection_name`` as a GetResult.

        Unlike ``search``/``query``, errors from ``get_collection`` propagate.
        Returns None only if the collection object is falsy.
        """
        collection = self.client.get_collection(name=collection_name)
        if not collection:
            return None
        result = collection.get()
        return GetResult(
            ids=[result["ids"]],
            documents=[result["documents"]],
            metadatas=[result["metadatas"]],
        )

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Add ``items`` to the collection, creating it (cosine space) if needed."""
        self._write_in_batches(collection_name, items, upsert=False)

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Insert-or-update ``items``, creating the collection (cosine space) if needed."""
        self._write_in_batches(collection_name, items, upsert=True)

    def _write_in_batches(
        self, collection_name: str, items: list[VectorItem], *, upsert: bool
    ) -> None:
        """Get-or-create the collection and write ``items`` in server-sized batches.

        Batching via ``create_batches`` keeps each call within the client's
        maximum batch size; it applies to both add and upsert.
        """
        collection = self.client.get_or_create_collection(
            name=collection_name, metadata={"hnsw:space": "cosine"}
        )
        if not items:
            # Nothing to write; skip the call (Chroma validates ids are non-empty).
            return

        write = collection.upsert if upsert else collection.add
        batches = create_batches(
            api=self.client,
            ids=[item["id"] for item in items],
            documents=[item["text"] for item in items],
            embeddings=[item["vector"] for item in items],
            metadatas=[item["metadata"] for item in items],
        )
        # create_batches yields (ids, embeddings, metadatas, documents) tuples.
        for batch_ids, batch_embeddings, batch_metadatas, batch_documents in batches:
            write(
                ids=batch_ids,
                embeddings=batch_embeddings,
                metadatas=batch_metadatas,
                documents=batch_documents,
            )

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """Delete records by ``ids`` or, if no ids are given, by metadata ``filter``.

        When both are supplied only ``ids`` is used; when neither is (or both
        are empty), nothing is deleted.
        """
        collection = self.client.get_collection(name=collection_name)
        if not collection:
            return
        if ids:
            collection.delete(ids=ids)
        elif filter:
            collection.delete(where=filter)

    def reset(self):
        """Call the Chroma client's ``reset()``; ``__init__`` enables ``allow_reset``."""
        return self.client.reset()
|
|