# chatbot/src/chroma_storage.py
# Author: samiha123
# Commit: dc22afe ("first commit")
import os
import chromadb
from chromadb.utils import embedding_functions
class ChromaStorage:
    """Wrapper around a single collection in a persistent Chroma database.

    Opens (or creates) one named collection in an on-disk Chroma store and
    exposes batched insertion, similarity queries, and a way to empty the
    collection.
    """

    def __init__(
        self,
        db_path: str = './src/chroma_db',
        collection_name: str = 'my_collection',
        embedding_function=None,
    ):
        """Open the database at *db_path* and bind *collection_name*.

        Args:
            db_path: Directory holding the persistent Chroma database.
            collection_name: Name of the collection to open or create.
            embedding_function: Optional Chroma embedding function (e.g. one
                built from ``chromadb.utils.embedding_functions``). When
                ``None``, Chroma's default embedding function is used. The
                same function is supplied whether the collection is newly
                created or reopened, so stored and query embeddings agree.
        """
        self.client = chromadb.PersistentClient(path=db_path)
        self.collection_name = collection_name
        self.embedding_function = embedding_function
        self.collection = self._get_or_create_collection(collection_name)

    def _get_or_create_collection(self, name: str):
        """Return the collection called *name*, creating it if missing."""
        # Forward the embedding function only when one was configured:
        # passing embedding_function=None explicitly would disable
        # embedding instead of falling back to Chroma's default.
        kwargs = {}
        if self.embedding_function is not None:
            kwargs['embedding_function'] = self.embedding_function
        return self.client.get_or_create_collection(name=name, **kwargs)

    def add_batch(
        self,
        documents: list,
        metadatas: list,
        ids: list,
        batch_size: int = 200,
    ):
        """Add documents to the collection in chunks of at most *batch_size*.

        ``documents[i]``, ``metadatas[i]`` and ``ids[i]`` describe the same
        record, so the three lists must be the same length.

        Raises:
            ValueError: if the lists differ in length or *batch_size* is
                not positive.
        """
        if not len(documents) == len(metadatas) == len(ids):
            raise ValueError(
                'documents, metadatas and ids must have the same length '
                f'(got {len(documents)}, {len(metadatas)}, {len(ids)})'
            )
        if batch_size <= 0:
            raise ValueError(f'batch_size must be positive, got {batch_size}')
        for start in range(0, len(documents), batch_size):
            end = start + batch_size
            self.collection.add(
                documents=documents[start:end],
                metadatas=metadatas[start:end],
                ids=ids[start:end],
            )

    def query(self, query_text: str, k: int = 1) -> dict:
        """Return the *k* nearest records to *query_text*.

        The result is Chroma's query result mapping (per-query lists such as
        ``ids``, ``documents`` and ``distances``), returned unchanged.
        """
        return self.collection.query(query_texts=[query_text], n_results=k)

    def delete_all(self):
        """Remove every record from this storage's collection.

        Drops the collection, recreates it empty, and rebinds
        ``self.collection`` so later calls use the fresh handle. Unlike
        ``client.reset()``, this does not need the client to be created with
        ``allow_reset=True`` and does not touch other collections.
        """
        self.client.delete_collection(name=self.collection_name)
        self.collection = self._get_or_create_collection(self.collection_name)