| | """ |
| | ChromaDB vector storage interface. |
| | |
| | This module provides a clean interface to ChromaDB for storing and retrieving |
| | document chunks with their embeddings and metadata. |
| | """ |
| |
|
| | import chromadb |
| | from typing import List, Optional |
| | import numpy as np |
| | import json |
| | from datetime import datetime |
| | from src.config.settings import get_settings, get_collection_name_for_model, EMBEDDING_MODELS |
| | from src.utils.logging import get_logger |
| | from src.ingestion.models import Chunk |
| |
|
| | logger = get_logger(__name__) |
| |
|
| |
|
class VectorStore:
    """ChromaDB interface for vector storage.

    Wraps a persistent ChromaDB client and exposes helpers for adding,
    querying, and deleting document chunks. Collections are namespaced per
    embedding model (via ``get_collection_name_for_model``) so embeddings of
    different dimensionality never share a collection.
    """

    def __init__(self, embedding_model: Optional[str] = None):
        """
        Initialize vector store with settings from configuration.

        Args:
            embedding_model: Optional embedding model ID. If provided, uses a
                model-specific collection; otherwise the configured default
                embedding model is used.
        """
        settings = get_settings()
        self.persist_dir = settings.chroma_persist_dir
        self._base_collection_name = settings.chroma_collection_name
        self._embedding_model = embedding_model or settings.embedding_model

        # Derive the collection name from the model so switching models
        # never mixes embeddings of different dimensions.
        self.collection_name = get_collection_name_for_model(
            self._embedding_model,
            self._base_collection_name,
        )

        # Client and collection are created lazily on first use.
        self._client = None
        self._collection = None

    @property
    def client(self):
        """
        Lazily initialize the ChromaDB client.

        Returns:
            chromadb.Client: Persistent ChromaDB client instance.
        """
        if self._client is None:
            logger.info("Initializing ChromaDB client: %s", self.persist_dir)
            self._client = chromadb.PersistentClient(path=self.persist_dir)
            logger.debug("ChromaDB client initialized")
        return self._client

    def get_collection(self):
        """
        Get or create the active collection.

        Returns:
            chromadb.Collection: Collection instance.
        """
        if self._collection is None:
            self._collection = self.client.get_or_create_collection(
                name=self.collection_name,
                metadata={"description": "Hierarchical PDF chunks with embeddings"}
            )
            logger.info("Collection loaded: %s", self.collection_name)
        return self._collection

    def add_chunks(self, chunks: List[Chunk], embeddings: np.ndarray):
        """
        Add chunks with embeddings to ChromaDB.

        Args:
            chunks: List of chunks to store.
            embeddings: Numpy array of embeddings (num_chunks x embedding_dim).

        Raises:
            ValueError: If the number of chunks and embeddings differ.
        """
        if len(chunks) != len(embeddings):
            raise ValueError(f"Number of chunks ({len(chunks)}) != number of embeddings ({len(embeddings)})")

        collection = self.get_collection()

        ids = [str(chunk.chunk_id) for chunk in chunks]
        documents = [chunk.text for chunk in chunks]
        metadatas = [self._prepare_metadata(chunk) for chunk in chunks]

        logger.info("Adding %d chunks to ChromaDB", len(chunks))

        collection.add(
            ids=ids,
            embeddings=embeddings.tolist(),
            documents=documents,
            metadatas=metadatas
        )

        logger.info("Successfully added %d chunks", len(chunks))

    def _prepare_metadata(self, chunk: Chunk) -> dict:
        """
        Prepare metadata for ChromaDB storage.

        ChromaDB metadata can only contain: str, int, float, bool.
        Lists (page_numbers) are JSON-encoded; a missing parent_id is
        stored as "" because None is not a valid metadata value.

        Args:
            chunk: Chunk to extract metadata from.

        Returns:
            dict: Metadata dictionary of ChromaDB-compatible scalar values.
        """
        return {
            "chunk_id": str(chunk.chunk_id),
            "document_id": str(chunk.document_id),
            "parent_id": str(chunk.parent_id) if chunk.parent_id else "",
            "chunk_type": chunk.chunk_type,
            "token_count": chunk.token_count,
            "chunk_index": chunk.chunk_index,
            "page_numbers": json.dumps(chunk.page_numbers),
            "start_char": chunk.start_char,
            "end_char": chunk.end_char,
            "file_hash": chunk.file_hash,
            "filename": chunk.filename,
        }

    def document_exists(self, file_hash: str) -> bool:
        """
        Check if a document with the given hash already exists.

        Best-effort: any lookup failure is logged and treated as
        "does not exist" so ingestion can proceed.

        Args:
            file_hash: SHA256 hash of the document.

        Returns:
            bool: True if at least one chunk with this file_hash is stored.
        """
        collection = self.get_collection()

        try:
            results = collection.get(
                where={"file_hash": file_hash},
                limit=1
            )
            exists = len(results['ids']) > 0
            if exists:
                logger.debug("Document with hash %s... already exists", file_hash[:8])
            return exists
        except Exception as e:
            # Deliberate best-effort: report "not found" rather than fail.
            logger.debug("Document check failed: %s", e)
            return False

    def get_chunk(self, chunk_id: str) -> Optional[dict]:
        """
        Retrieve a specific chunk by ID.

        Args:
            chunk_id: UUID of the chunk to retrieve.

        Returns:
            Optional[dict]: Dict with id/document/metadata/embedding keys,
            or None if the chunk is not found or retrieval fails.
        """
        collection = self.get_collection()

        try:
            results = collection.get(
                ids=[chunk_id],
                include=["documents", "metadatas", "embeddings"]
            )

            if len(results['ids']) > 0:
                # ChromaDB may return embeddings as a numpy array, whose
                # truth value is ambiguous ("if results['embeddings']"
                # raises ValueError) -- check None/length explicitly.
                embeddings = results.get('embeddings')
                embedding = (
                    embeddings[0]
                    if embeddings is not None and len(embeddings) > 0
                    else None
                )
                return {
                    "id": results['ids'][0],
                    "document": results['documents'][0],
                    "metadata": results['metadatas'][0],
                    "embedding": embedding,
                }
            return None
        except Exception as e:
            logger.error("Failed to retrieve chunk %s: %s", chunk_id, e)
            return None

    def delete_document(self, document_id: str):
        """
        Delete all chunks for a document.

        Args:
            document_id: UUID of the document to delete.

        Raises:
            Exception: Re-raises any ChromaDB deletion failure after logging.
        """
        collection = self.get_collection()

        try:
            collection.delete(
                where={"document_id": document_id}
            )
            logger.info("Deleted all chunks for document: %s", document_id)
        except Exception as e:
            logger.error("Failed to delete document %s: %s", document_id, e)
            raise

    def get_collection_stats(self) -> dict:
        """
        Get statistics about the active collection.

        Returns:
            dict: Name, chunk count, persist dir, and embedding model;
            empty dict on failure.
        """
        collection = self.get_collection()

        try:
            count = collection.count()
            return {
                "name": self.collection_name,
                "total_chunks": count,
                "persist_dir": self.persist_dir,
                "embedding_model": self._embedding_model,
            }
        except Exception as e:
            logger.error("Failed to get collection stats: %s", e)
            return {}

    def list_all_collections(self) -> List[dict]:
        """
        List all model collections with their stats.

        Iterates every known embedding model and reports its collection,
        whether or not it exists yet (missing collections report 0 chunks).

        Returns:
            List[dict]: One info dictionary per configured embedding model.
        """
        collections = []

        for model_id, model_config in EMBEDDING_MODELS.items():
            collection_name = get_collection_name_for_model(
                model_id,
                self._base_collection_name,
            )
            try:
                total_chunks = self.client.get_collection(name=collection_name).count()
            except Exception:
                # Collection does not exist yet for this model.
                total_chunks = 0
            collections.append({
                "collection_name": collection_name,
                "embedding_model": model_id,
                "model_name": model_config.get("name", model_id),
                "dimensions": model_config.get("dimensions"),
                "total_chunks": total_chunks,
                "is_active": model_id == self._embedding_model,
            })

        return collections

    def switch_collection(self, embedding_model: str):
        """
        Switch to a different collection based on embedding model.

        Args:
            embedding_model: Embedding model ID to switch to.
        """
        self._embedding_model = embedding_model
        self.collection_name = get_collection_name_for_model(
            embedding_model,
            self._base_collection_name,
        )
        # Drop the cached collection so the next access re-resolves it.
        self._collection = None
        logger.info("Switched to collection: %s", self.collection_name)

    def query(
        self,
        query_embedding: np.ndarray,
        top_k: int = 10,
        filter_filenames: Optional[List[str]] = None,
    ) -> dict:
        """
        Query the collection with an embedding.

        Args:
            query_embedding: Query embedding vector.
            top_k: Number of results to return.
            filter_filenames: Optional list of filenames to filter results.

        Returns:
            dict: Query results with ids, documents, metadatas, and distances
            (each nested one level: one inner list per query embedding).
        """
        collection = self.get_collection()

        try:
            # Build an optional metadata filter; a single filename uses a
            # direct equality match, multiple use $in.
            where_clause = None
            if filter_filenames:
                if len(filter_filenames) == 1:
                    where_clause = {"filename": filter_filenames[0]}
                else:
                    where_clause = {"filename": {"$in": filter_filenames}}

            results = collection.query(
                query_embeddings=[query_embedding.tolist()],
                n_results=top_k,
                include=["documents", "metadatas", "distances"],
                where=where_clause,
            )
            return results
        except Exception as e:
            logger.error("Query failed: %s", e)
            # Match the nested shape of a successful chroma query (one inner
            # list per query) so callers indexing [0] don't crash on errors.
            return {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}
| |
|