| import os |
| import chromadb |
| from chromadb.utils import embedding_functions |
|
|
class ChromaStorage:
    """Small convenience wrapper around one persistent Chroma collection.

    The instance owns a ``chromadb.PersistentClient`` and a handle to a single
    named collection, and exposes batched insert, text query and a full reset.

    Attributes:
        collection_name: Name of the collection this instance works on.
        openai_ef: Embedding function attached to the collection.
        client: The underlying persistent Chroma client.
        collection: Handle to the collection named ``collection_name``.
    """

    def __init__(self, db_path: str = './src/chroma_db', collection_name: str = 'my_collection',
                 *, embedding_function=None):
        """Open (or create) the on-disk database and the target collection.

        Args:
            db_path: Directory handed to ``chromadb.PersistentClient``.
            collection_name: Collection to open, created when missing.
            embedding_function: Optional embedding function to attach to the
                collection. When omitted, an ``OpenAIEmbeddingFunction`` is
                built from the ``OPENAI_API_KEY`` environment variable.
        """
        self.collection_name = collection_name

        # NOTE(review): the original code read ``self.openai_ef`` without ever
        # assigning it. An OpenAI embedding function keyed from the
        # environment is presumably what was intended (attribute name plus the
        # ``os`` / ``embedding_functions`` imports) -- confirm the model choice.
        if embedding_function is None:
            embedding_function = embedding_functions.OpenAIEmbeddingFunction(
                api_key=os.environ.get('OPENAI_API_KEY'),
            )
        # Resolved before touching the database so a bad configuration fails
        # before any on-disk state is created.
        self.openai_ef = embedding_function

        self.client = chromadb.PersistentClient(path=db_path)
        self.collection = self._get_or_create_collection(collection_name)

    def _get_or_create_collection(self, name: str):
        """Return the collection called ``name``, creating it if necessary.

        The same embedding function is supplied whether the collection already
        exists or is newly created, so documents and queries are embedded
        consistently. Errors from the client are propagated instead of being
        interpreted as "collection does not exist".
        """
        return self.client.get_or_create_collection(
            name=name,
            embedding_function=self.openai_ef,
        )

    def add_batch(self, documents: list, metadatas: list, ids: list, batch_size: int = 200):
        """Insert documents into the collection in chunks of ``batch_size``.

        Args:
            documents: Document texts to store.
            metadatas: One metadata mapping per document, or ``None`` to store
                the documents without metadata.
            ids: One unique id per document.
            batch_size: Maximum number of documents per ``collection.add`` call.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1, or the input lists
                do not all have the same length. Validation happens before the
                first write, so a bad call never leaves a partial insert.
        """
        if batch_size < 1:
            raise ValueError(f'batch_size must be a positive integer, got {batch_size}')

        total = len(documents)
        if len(ids) != total:
            raise ValueError(
                f'documents and ids must have the same length ({total} != {len(ids)})'
            )
        if metadatas is not None and len(metadatas) != total:
            raise ValueError(
                f'documents and metadatas must have the same length ({total} != {len(metadatas)})'
            )

        for start in range(0, total, batch_size):
            stop = start + batch_size
            self.collection.add(
                documents=documents[start:stop],
                metadatas=None if metadatas is None else metadatas[start:stop],
                ids=ids[start:stop],
            )

    def query(self, query_text: str, k: int = 1) -> dict:
        """Return the ``k`` nearest results for a single query string.

        The value is passed through unchanged from ``collection.query``, which
        is called with ``query_text`` as its only query.
        """
        return self.collection.query(query_texts=[query_text], n_results=k)

    def delete_all(self):
        """Reset the whole database and re-open an empty collection.

        ``client.reset()`` wipes every collection in the database, not only
        this instance's. Chroma only permits it when resets are enabled in the
        client configuration (``allow_reset``); otherwise the call raises.
        """
        self.client.reset()
        # The previous handle refers to a collection that no longer exists;
        # rebuild it so the instance stays usable after the reset.
        self.collection = self._get_or_create_collection(self.collection_name)
|
|