Spaces:
Sleeping
Sleeping
| from typing import Dict, Any, Optional, List | |
| import chromadb | |
| from chromadb.config import Settings | |
| from chromadb.api.types import ( | |
| Where, | |
| GetResult, | |
| QueryResult, | |
| ) | |
| from ..embedding_provider import EmbeddingProvider | |
| from .vector_store import VectorStore | |
class ChromaDB(VectorStore):
    """
    ChromaDB is an example of a vector-store class implementation, backed by a
    persistent ``chromadb`` client.

    See more:
    https://github.com/chroma-core/chroma
    """

    def __init__(
        self,
        configs: Optional[Dict[str, Any]] = None,
        db_path: str = ".chromadb",
        embedding_function: Optional[EmbeddingProvider] = None,
        collection_name: Optional[str] = None,
    ) -> None:
        """
        Open (or create) a persistent ChromaDB store and its collection.

        Args:
            configs (Optional[Dict[str, Any]]): Configuration stored on the
                instance as ``self.configs``. Defaults to a fresh empty dict.
            db_path (str): Directory handed to ``chromadb.PersistentClient``.
            embedding_function (Optional[EmbeddingProvider]): When given, used to
                embed documents and queries client-side; otherwise raw texts are
                passed to ChromaDB.
            collection_name (Optional[str]): Collection to open or create;
                ``"default_collection"`` is used when not provided.
        """
        self.client = chromadb.PersistentClient(path=db_path)
        # A fresh dict per instance: a literal `{}` default would be shared.
        self.configs = configs if configs is not None else {}
        self.embedding_function = embedding_function
        self._collection_name = collection_name
        self.collection = self._get_or_create_collection()

    def _get_or_create_collection(self):
        """Open (or create) the configured collection on the current client."""
        return self.client.get_or_create_collection(
            name=self._collection_name or "default_collection"
        )

    @property
    def db_path(self) -> str:
        """Directory the current persistent client stores its data in."""
        return self.client.get_settings().persist_directory

    @db_path.setter
    def db_path(self, value: str) -> None:
        """Re-open the store at ``value`` and re-acquire the collection there."""
        self.client = chromadb.PersistentClient(path=value)
        self.collection = self._get_or_create_collection()

    @property
    def collection_name(self) -> Optional[str]:
        """The collection name given by the caller (``None`` if none was given)."""
        return self._collection_name

    @collection_name.setter
    def collection_name(self, value: str) -> None:
        """Rename the underlying collection to ``value``."""
        # Rename first so a failing modify() leaves the cached name unchanged.
        self.collection.modify(name=value)
        self._collection_name = value

    def add_data(
        self,
        documents: List[str],
        ids: List[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
        **optional_kwargs,
    ) -> None:
        """
        Add data to the collection, embedding it client-side when an
        embedding function is configured.

        Args:
            documents (List[str]): List of documents to add.
            ids (List[str]): List of ids for the documents.
            metadatas (Optional[List[Dict[str, Any]]]): List of metadata for the
                documents. An empty list is sent as ``None``.
            **optional_kwargs: Additional keyword arguments (see collection.add
                for more). ``metadatas``/``embeddings`` set here are overridden
                by the explicit handling below.

        Raises:
            Exception: Any error from embedding or ``collection.add`` is printed
                and re-raised.
        """
        try:
            params: Dict[str, Any] = {
                "documents": documents,
                "ids": ids,
                **optional_kwargs,
            }
            params["metadatas"] = metadatas or None
            # With an embedding function, compute embeddings ourselves instead
            # of letting ChromaDB use the collection's own embedder.
            if self.embedding_function:
                params["embeddings"] = self.embedding_function.embed_documents(documents)
            self.collection.add(**params)
        except Exception as e:
            print(f"Error adding data to collection: {e}")
            raise

    def search(
        self,
        query_text: Optional[List[str]] = None,
        query_embedding: Optional[List[List[float]]] = None,
        n_results: int = 10,
        **optional_kwargs,
    ) -> QueryResult:
        """
        Query the collection for similar documents.

        Args:
            query_text (Optional[List[str]]): List of query texts.
            query_embedding (Optional[List[List[float]]]): List of query embeddings.
            n_results (int): Number of results to return per query.
            **optional_kwargs: Additional keyword arguments (see collection.query
                for more).

        Returns:
            QueryResult: The result of the query. When both texts and embeddings
            are given and an embedding function is configured, results for the
            provided embeddings come first, followed by those for the texts.

        Raises:
            ValueError: If neither a non-empty ``query_text`` nor a non-empty
                ``query_embedding`` is provided.
        """
        try:
            has_text = bool(query_text)
            has_embedding = bool(query_embedding)
            if not has_text and not has_embedding:
                raise ValueError("Either query_text or query_embedding must be provided.")

            params: Dict[str, Any] = {
                "n_results": n_results,
                **optional_kwargs,
            }

            if has_text and self.embedding_function:
                # NOTE(review): embed_query is called with the whole list of texts,
                # as before; confirm EmbeddingProvider.embed_query returns one
                # vector per text when given a list.
                text_embeddings = self.embedding_function.embed_query(query_text)
                if has_embedding:
                    # Build a new list: list.extend() returns None and would
                    # also mutate the caller's list.
                    params["query_embeddings"] = list(query_embedding) + list(text_embeddings)
                else:
                    params["query_embeddings"] = text_embeddings
            elif has_text:
                # Collection.query's parameter is `query_texts` (plural).
                params["query_texts"] = query_text
                if has_embedding:
                    # NOTE(review): ChromaDB may reject a request carrying both
                    # query_texts and query_embeddings.
                    params["query_embeddings"] = query_embedding
            else:
                params["query_embeddings"] = query_embedding

            return self.collection.query(**params)
        except Exception as e:
            print(f"Error querying data from collection: {e}")
            raise

    def query_by_id_or_metadata(
        self,
        ids: Optional[List[str]] = None,
        where: Optional[Where] = None,
        n_results: int = 10,
        **optional_kwargs,
    ) -> GetResult:
        """
        Fetch documents from the collection by id and/or metadata filter.

        Args:
            ids (Optional[List[str]]): List of ids to fetch.
            where (Optional[Where]): Metadata filter.
            n_results (int): Maximum number of records to return (forwarded to
                ``collection.get`` as ``limit``).
            **optional_kwargs: Additional keyword arguments (see collection.get
                for more); an explicit ``limit`` here overrides ``n_results``.

        Returns:
            GetResult: The fetched records.

        Raises:
            ValueError: If neither ``ids`` nor ``where`` is provided.
        """
        try:
            if ids is None and where is None:
                raise ValueError("Either ids or where must be provided.")
            params: Dict[str, Any] = {
                # Collection.get() has no `n_results`; its result cap is `limit`.
                "limit": n_results,
                **optional_kwargs,
            }
            if ids:
                params["ids"] = ids
            if where:
                params["where"] = where
            return self.collection.get(**params)
        except Exception as e:
            print(f"Error querying data from collection: {e}")
            raise

    def delete_collection(self, collection_name: Optional[str] = None) -> None:
        """
        Delete a specific collection from the ChromaDB.

        Errors are printed rather than raised.

        Args:
            collection_name (Optional[str]): Name of collection to delete.
                Uses the instance's ``collection_name`` if not provided.
        """
        try:
            target_collection = collection_name or self.collection_name
            if not target_collection:
                raise ValueError("No collection name provided")
            self.client.delete_collection(name=target_collection)
            # NOTE(review): if this was the instance's own collection,
            # self.collection still refers to the deleted one.
            print(f"Collection '{target_collection}' deleted successfully.")
        except Exception as e:
            print(f"Error deleting collection: {e}")