"""
ChromaDB vector storage interface.

This module provides a clean interface to ChromaDB for storing and retrieving
document chunks with their embeddings and metadata.
"""

import json
from datetime import datetime
from typing import List, Optional

import chromadb
import numpy as np

from src.config.settings import get_settings, get_collection_name_for_model, EMBEDDING_MODELS
from src.utils.logging import get_logger
from src.ingestion.models import Chunk

logger = get_logger(__name__)


class VectorStore:
    """ChromaDB interface for vector storage.

    Each embedding model ID maps to its own collection, whose name is derived
    via ``get_collection_name_for_model`` from the model ID and the configured
    base collection name. The ChromaDB client and collection are created
    lazily on first use.
    """

    def __init__(self, embedding_model: Optional[str] = None):
        """
        Initialize vector store with settings from configuration.

        Args:
            embedding_model: Optional embedding model ID. When falsy, falls
                back to ``settings.embedding_model``. Selects which
                model-specific collection is used.
        """
        settings = get_settings()
        self.persist_dir = settings.chroma_persist_dir
        self._base_collection_name = settings.chroma_collection_name
        self._embedding_model = embedding_model or settings.embedding_model

        # Use model-specific collection name
        self.collection_name = get_collection_name_for_model(
            self._embedding_model, self._base_collection_name
        )

        # Populated lazily by the ``client`` property and ``get_collection``.
        self._client = None
        self._collection = None

    @property
    def client(self):
        """
        Lazily initialize and return the persistent ChromaDB client.

        Returns:
            chromadb.PersistentClient: Client rooted at ``self.persist_dir``.
        """
        if self._client is None:
            logger.info(f"Initializing ChromaDB client: {self.persist_dir}")
            self._client = chromadb.PersistentClient(path=self.persist_dir)
            logger.debug("ChromaDB client initialized")
        return self._client

    def get_collection(self):
        """
        Get or create the active collection, caching it on the instance.

        Returns:
            chromadb.Collection: Collection named ``self.collection_name``.
        """
        if self._collection is None:
            self._collection = self.client.get_or_create_collection(
                name=self.collection_name,
                metadata={"description": "Hierarchical PDF chunks with embeddings"},
            )
            logger.info(f"Collection loaded: {self.collection_name}")
        return self._collection

    def add_chunks(self, chunks: List[Chunk], embeddings: np.ndarray):
        """
        Add chunks with embeddings to ChromaDB.

        Args:
            chunks: List of chunks to store.
            embeddings: One embedding row per chunk
                (num_chunks x embedding_dim). A NumPy array is expected, but
                any array-like (e.g. a list of lists) is accepted.

        Raises:
            ValueError: If the number of chunks and embeddings differ.
            Exceptions raised by ``collection.add`` propagate unchanged.
        """
        embeddings = np.asarray(embeddings)
        if len(chunks) != len(embeddings):
            raise ValueError(
                f"Number of chunks ({len(chunks)}) != number of embeddings ({len(embeddings)})"
            )

        if not chunks:
            # ChromaDB validates that the ID list is non-empty, so an empty
            # batch would raise; treat it as a no-op instead.
            logger.info("No chunks to add to ChromaDB")
            return

        collection = self.get_collection()

        # Prepare data for ChromaDB
        ids = [str(chunk.chunk_id) for chunk in chunks]
        documents = [chunk.text for chunk in chunks]
        metadatas = [self._prepare_metadata(chunk) for chunk in chunks]

        logger.info(f"Adding {len(chunks)} chunks to ChromaDB")

        collection.add(
            ids=ids,
            embeddings=embeddings.tolist(),
            documents=documents,
            metadatas=metadatas,
        )

        logger.info(f"Successfully added {len(chunks)} chunks")

    def _prepare_metadata(self, chunk: Chunk) -> dict:
        """
        Prepare metadata for ChromaDB storage.

        ChromaDB metadata can only contain: str, int, float, bool.
        Lists must be JSON-encoded.

        Args:
            chunk: Chunk to extract metadata from.

        Returns:
            dict: Flat metadata dictionary for this chunk.
        """
        return {
            "chunk_id": str(chunk.chunk_id),
            "document_id": str(chunk.document_id),
            # Missing parent is stored as "" because None is not a valid scalar here.
            "parent_id": str(chunk.parent_id) if chunk.parent_id else "",
            "chunk_type": chunk.chunk_type,
            "token_count": chunk.token_count,
            "chunk_index": chunk.chunk_index,
            # Lists are not allowed as metadata values, so encode as JSON text.
            "page_numbers": json.dumps(chunk.page_numbers),
            "start_char": chunk.start_char,
            "end_char": chunk.end_char,
            "file_hash": chunk.file_hash,
            "filename": chunk.filename,
        }

    def document_exists(self, file_hash: str) -> bool:
        """
        Check if a document with the given hash already exists.

        Args:
            file_hash: SHA256 hash of document.

        Returns:
            bool: True if at least one stored chunk carries this ``file_hash``.
                Returns False (rather than raising) if the lookup fails.
        """
        collection = self.get_collection()

        try:
            # Try to query for any chunk with this file hash
            results = collection.get(
                where={"file_hash": file_hash},
                limit=1,
            )
            exists = len(results['ids']) > 0

            if exists:
                logger.debug(f"Document with hash {file_hash[:8]}... already exists")

            return exists
        except Exception as e:
            # If metadata field doesn't exist, document doesn't exist
            logger.debug(f"Document check failed: {e}")
            return False

    def get_chunk(self, chunk_id: str) -> Optional[dict]:
        """
        Retrieve a specific chunk by ID.

        Args:
            chunk_id: UUID of chunk to retrieve.

        Returns:
            Optional[dict]: ``{"id", "document", "metadata", "embedding"}`` for
                the chunk, or None if it is not found or retrieval fails.
        """
        collection = self.get_collection()

        try:
            results = collection.get(
                ids=[chunk_id],
                include=["documents", "metadatas", "embeddings"],
            )

            if len(results['ids']) > 0:
                # Newer ChromaDB versions return embeddings as a NumPy array,
                # whose truth value is ambiguous; test for presence explicitly.
                embeddings = results.get('embeddings')
                embedding = None
                if embeddings is not None and len(embeddings) > 0:
                    embedding = embeddings[0]

                return {
                    "id": results['ids'][0],
                    "document": results['documents'][0],
                    "metadata": results['metadatas'][0],
                    "embedding": embedding,
                }
            return None
        except Exception as e:
            logger.error(f"Failed to retrieve chunk {chunk_id}: {e}")
            return None

    def delete_document(self, document_id: str):
        """
        Delete all chunks for a document.

        Args:
            document_id: UUID of document to delete.

        Raises:
            Any exception raised by ChromaDB's delete is logged and re-raised.
        """
        collection = self.get_collection()

        try:
            collection.delete(
                where={"document_id": document_id}
            )
            logger.info(f"Deleted all chunks for document: {document_id}")
        except Exception as e:
            logger.error(f"Failed to delete document {document_id}: {e}")
            raise

    def get_collection_stats(self) -> dict:
        """
        Get statistics about the active collection.

        Returns:
            dict: Collection name, chunk count, persist directory and embedding
                model, or an empty dict if the count cannot be read.
        """
        collection = self.get_collection()

        try:
            count = collection.count()
            return {
                "name": self.collection_name,
                "total_chunks": count,
                "persist_dir": self.persist_dir,
                "embedding_model": self._embedding_model,
            }
        except Exception as e:
            logger.error(f"Failed to get collection stats: {e}")
            return {}

    def list_all_collections(self) -> List[dict]:
        """
        List the collection for every configured embedding model, with stats.

        Collections that have not been created yet (or cannot be read) are
        reported with ``total_chunks`` of 0.

        Returns:
            List[dict]: One info dict per entry in ``EMBEDDING_MODELS``.
        """
        collections = []

        for model_id, model_config in EMBEDDING_MODELS.items():
            collection_name = get_collection_name_for_model(
                model_id, self._base_collection_name
            )

            try:
                total_chunks = self.client.get_collection(name=collection_name).count()
            except Exception:
                # Collection doesn't exist yet (the exception type raised by
                # get_collection differs between ChromaDB versions).
                total_chunks = 0

            collections.append({
                "collection_name": collection_name,
                "embedding_model": model_id,
                "model_name": model_config.get("name", model_id),
                "dimensions": model_config.get("dimensions"),
                "total_chunks": total_chunks,
                "is_active": model_id == self._embedding_model,
            })

        return collections

    def switch_collection(self, embedding_model: str):
        """
        Switch to a different collection based on embedding model.

        Args:
            embedding_model: Embedding model ID to switch to.
        """
        self._embedding_model = embedding_model
        self.collection_name = get_collection_name_for_model(
            embedding_model, self._base_collection_name
        )
        # Drop the cached collection so the next access loads the new one.
        self._collection = None
        logger.info(f"Switched to collection: {self.collection_name}")

    def query(
        self,
        query_embedding: np.ndarray,
        top_k: int = 10,
        filter_filenames: Optional[List[str]] = None,
    ) -> dict:
        """
        Query the collection with an embedding.

        Args:
            query_embedding: Query embedding vector (NumPy array or array-like).
            top_k: Number of results to return.
            filter_filenames: Optional list of filenames to filter results.

        Returns:
            dict: ChromaDB query results with ids, documents, metadatas and
                distances. Each value holds one inner list per query embedding
                (here exactly one). On failure, the same shape is returned with
                empty inner lists.
        """
        collection = self.get_collection()

        try:
            # Build where clause for filtering
            where_clause = None
            if filter_filenames:
                if len(filter_filenames) == 1:
                    where_clause = {"filename": filter_filenames[0]}
                else:
                    where_clause = {"filename": {"$in": filter_filenames}}

            return collection.query(
                query_embeddings=[np.asarray(query_embedding).tolist()],
                n_results=top_k,
                include=["documents", "metadatas", "distances"],
                where=where_clause,
            )
        except Exception as e:
            logger.error(f"Query failed: {e}")
            # Keep ChromaDB's nested per-query shape so callers that index
            # results["ids"][0] see an empty list instead of an IndexError.
            return {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}