"""Thin helpers around a persistent ChromaDB collection of text chunks."""

from functools import lru_cache

import chromadb
from chromadb.api.models.Collection import Collection

from src.config import settings


@lru_cache(maxsize=1)
def get_chroma_client() -> chromadb.ClientAPI:
    """Return the ChromaDB persistent client for ``settings.chroma_db_path``.

    The function takes no arguments and is wrapped in ``lru_cache``, so the
    client is constructed on the first call and the same object is returned
    on subsequent calls.
    """
    return chromadb.PersistentClient(path=str(settings.chroma_db_path))


def get_collection() -> Collection:
    """Return the collection named ``settings.chroma_collection_name``.

    Uses ``get_or_create_collection``, so the collection is created on first
    use. The ``hnsw:space`` metadata key is set to ``"cosine"``.
    """
    client = get_chroma_client()
    return client.get_or_create_collection(
        name=settings.chroma_collection_name,
        metadata={"hnsw:space": "cosine"},
    )


def add_chunks(
    ids: list[str],
    documents: list[str],
    embeddings: list[list[float]],
    metadatas: list[dict],
    *,
    batch_size: int = 5000,
) -> None:
    """Add chunks to the collection in batches.

    Args:
        ids: Identifier for each chunk.
        documents: Text of each chunk, parallel to ``ids``.
        embeddings: Embedding vector of each chunk, parallel to ``ids``.
        metadatas: Metadata mapping of each chunk, parallel to ``ids``.
        batch_size: Maximum number of chunks sent per ``collection.add``
            call. Defaults to 5000.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if the four input
            lists do not all have the same length.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Validate up front: a mismatch discovered only in a later batch would
    # otherwise leave the earlier batches already written.
    total = len(ids)
    lengths = {
        "documents": len(documents),
        "embeddings": len(embeddings),
        "metadatas": len(metadatas),
    }
    mismatched = {name: n for name, n in lengths.items() if n != total}
    if mismatched:
        raise ValueError(
            f"All inputs must have the same length as ids ({total}); "
            f"mismatched: {mismatched}"
        )

    collection = get_collection()

    # ChromaDB has a batch limit, so the input is sent in slices.
    for start in range(0, total, batch_size):
        stop = start + batch_size
        collection.add(
            ids=ids[start:stop],
            documents=documents[start:stop],
            embeddings=embeddings[start:stop],
            metadatas=metadatas[start:stop],
        )


def query_chunks(
    query_embedding: list[float],
    n_results: int = 5,
    where: dict | None = None,
) -> dict:
    """Query the collection with a single embedding.

    Args:
        query_embedding: The embedding to search with.
        n_results: Number of results requested from the collection.
        where: Optional metadata filter. It is forwarded only when truthy,
            so ``None`` and an empty dict both mean "no filter".

    Returns:
        The result of ``collection.query``, requested with documents,
        metadatas and distances included.
    """
    collection = get_collection()
    kwargs = {
        "query_embeddings": [query_embedding],
        "n_results": n_results,
        "include": ["documents", "metadatas", "distances"],
    }
    if where:
        kwargs["where"] = where
    return collection.query(**kwargs)


def delete_by_book_id(book_id: str) -> None:
    """Delete every entry whose ``book_id`` metadata equals ``book_id``."""
    collection = get_collection()
    collection.delete(where={"book_id": book_id})


def get_collection_count() -> int:
    """Return the value reported by the collection's ``count()``."""
    collection = get_collection()
    return collection.count()