Spaces:
Paused
Paused
| import chromadb | |
| from chromadb import Settings | |
| from chromadb.utils.batch_utils import create_batches | |
| from typing import Optional | |
| from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult | |
| from open_webui.config import ( | |
| CHROMA_DATA_PATH, | |
| CHROMA_HTTP_HOST, | |
| CHROMA_HTTP_PORT, | |
| CHROMA_HTTP_HEADERS, | |
| CHROMA_HTTP_SSL, | |
| CHROMA_TENANT, | |
| CHROMA_DATABASE, | |
| ) | |
class ChromaClient:
    """Vector-store adapter backed by a ChromaDB client.

    The underlying client is selected once, at construction time: an HTTP
    client when ``CHROMA_HTTP_HOST`` is set, otherwise a persistent client
    rooted at ``CHROMA_DATA_PATH``.
    """

    def __init__(self):
        """Build the underlying Chroma client from the module-level config."""
        # Both client flavours use the same settings; allow_reset=True is
        # what makes reset() usable.
        settings = Settings(allow_reset=True, anonymized_telemetry=False)
        if CHROMA_HTTP_HOST:
            self.client = chromadb.HttpClient(
                host=CHROMA_HTTP_HOST,
                port=CHROMA_HTTP_PORT,
                headers=CHROMA_HTTP_HEADERS,
                ssl=CHROMA_HTTP_SSL,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
                settings=settings,
            )
        else:
            self.client = chromadb.PersistentClient(
                path=CHROMA_DATA_PATH,
                settings=settings,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
            )

    def _existing_collection(self, collection_name: str):
        """Return the named collection, or ``None`` if it does not exist.

        ``get_collection`` raises for an unknown name instead of returning a
        falsy value, and the exception type differs between chromadb
        releases. Existence is therefore checked by name first, so a missing
        collection maps to ``None`` without swallowing unrelated errors.
        """
        if not self.has_collection(collection_name):
            return None
        return self.client.get_collection(name=collection_name)

    def _write_batched(self, write, items: list[VectorItem]) -> None:
        """Send *items* through *write* in batches sized by ``create_batches``.

        Args:
            write: Bound ``collection.add`` or ``collection.upsert``.
            items: Items providing the "id", "text", "vector" and "metadata"
                keys.
        """
        if not items:
            # Nothing to write; avoid handing empty id lists to Chroma.
            return
        batches = create_batches(
            api=self.client,
            documents=[item["text"] for item in items],
            embeddings=[item["vector"] for item in items],
            ids=[item["id"] for item in items],
            metadatas=[item["metadata"] for item in items],
        )
        for batch in batches:
            # Each batch is splatted positionally; add() and upsert() are
            # expected to share the same leading positional parameters.
            write(*batch)

    def has_collection(self, collection_name: str) -> bool:
        """Return True if a collection named *collection_name* exists."""
        # Depending on the chromadb release, list_collections() yields either
        # collection objects or bare name strings. Strings are checked first
        # because the 0.6.x name wrapper raises on unknown attribute access.
        return any(
            (entry if isinstance(entry, str) else entry.name) == collection_name
            for entry in self.client.list_collections()
        )

    def delete_collection(self, collection_name: str):
        """Delete the collection named *collection_name*."""
        return self.client.delete_collection(name=collection_name)

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        """Return the *limit* nearest items for each query vector.

        Args:
            collection_name: Collection to query.
            vectors: Query embeddings, one per query.
            limit: Number of results requested per query vector.

        Returns:
            A ``SearchResult`` built from the query's ids, distances,
            documents and metadatas, or ``None`` if the collection does not
            exist.
        """
        collection = self._existing_collection(collection_name)
        if collection is None:
            return None
        result = collection.query(
            query_embeddings=vectors,
            n_results=limit,
        )
        return SearchResult(
            ids=result["ids"],
            distances=result["distances"],
            documents=result["documents"],
            metadatas=result["metadatas"],
        )

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return every item stored in the collection.

        Returns:
            A ``GetResult`` with ids, documents and metadatas, or ``None`` if
            the collection does not exist.
        """
        collection = self._existing_collection(collection_name)
        if collection is None:
            return None
        result = collection.get()
        # Each field is wrapped in an outer list -- presumably to match the
        # per-query nesting of search() results; confirm against GetResult.
        return GetResult(
            ids=[result["ids"]],
            documents=[result["documents"]],
            metadatas=[result["metadatas"]],
        )

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Add *items* to the collection, creating the collection if needed.

        An empty *items* list still creates the collection but writes nothing.
        """
        collection = self.client.get_or_create_collection(name=collection_name)
        self._write_batched(collection.add, items)

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Insert or update *items*, creating the collection if needed.

        Writes are batched the same way as ``insert`` so large item lists do
        not exceed the backend's maximum batch size. An empty *items* list
        still creates the collection but writes nothing.
        """
        collection = self.client.get_or_create_collection(name=collection_name)
        self._write_batched(collection.upsert, items)

    def delete(self, collection_name: str, ids: list[str]):
        """Delete the items with the given *ids* from the collection.

        Does nothing if the collection does not exist.
        """
        collection = self._existing_collection(collection_name)
        if collection is not None:
            collection.delete(ids=ids)

    def reset(self):
        """Reset the database, deleting all collections and their items."""
        return self.client.reset()