"""ChromaDB integration for vector storage and retrieval.""" from typing import List, Dict, Optional, Tuple import chromadb from chromadb.config import Settings import uuid import os from embedding_models import EmbeddingFactory, EmbeddingModel from chunking_strategies import ChunkingFactory import json class ChromaDBManager: """Manager for ChromaDB operations.""" def __init__(self, persist_directory: str = "./chroma_db"): """Initialize ChromaDB manager. Args: persist_directory: Directory to persist ChromaDB data """ self.persist_directory = persist_directory os.makedirs(persist_directory, exist_ok=True) # Initialize ChromaDB client with is_persistent=True to use persistent storage try: self.client = chromadb.PersistentClient( path=persist_directory, settings=Settings( anonymized_telemetry=False, allow_reset=True # Allow reset if needed ) ) except Exception as e: print(f"Warning: Could not create persistent client: {e}") print("Falling back to regular client...") self.client = chromadb.Client(Settings( persist_directory=persist_directory, anonymized_telemetry=False, allow_reset=True )) self.embedding_model = None self.current_collection = None def reconnect(self): """Reconnect to ChromaDB in case of connection loss.""" try: self.client = chromadb.PersistentClient( path=self.persist_directory, settings=Settings( anonymized_telemetry=False, allow_reset=True ) ) print("✅ Reconnected to ChromaDB") except Exception as e: print(f"Error reconnecting: {e}") def create_collection( self, collection_name: str, embedding_model_name: str, metadata: Optional[Dict] = None ) -> chromadb.Collection: """Create a new collection. Args: collection_name: Name of the collection embedding_model_name: Name of the embedding model metadata: Additional metadata for the collection Returns: ChromaDB collection """ # Delete if exists try: self.client.delete_collection(collection_name) except: pass # Create embedding model self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name) self.embedding_model.load_model() # Create collection with metadata collection_metadata = { "embedding_model": embedding_model_name, "hnsw:space": "cosine" } if metadata: collection_metadata.update(metadata) self.current_collection = self.client.create_collection( name=collection_name, metadata=collection_metadata ) print(f"Created collection: {collection_name}") return self.current_collection def get_collection(self, collection_name: str) -> chromadb.Collection: """Get an existing collection. Args: collection_name: Name of the collection Returns: ChromaDB collection """ self.current_collection = self.client.get_collection(collection_name) # Load embedding model from metadata metadata = self.current_collection.metadata if "embedding_model" in metadata: self.embedding_model = EmbeddingFactory.create_embedding_model( metadata["embedding_model"] ) self.embedding_model.load_model() return self.current_collection def list_collections(self) -> List[str]: """List all collections. Returns: List of collection names """ collections = self.client.list_collections() return [col.name for col in collections] def clear_all_collections(self) -> int: """Delete all collections from the database. Returns: Number of collections deleted """ collections = self.list_collections() count = 0 for collection_name in collections: try: self.client.delete_collection(collection_name) print(f"Deleted collection: {collection_name}") count += 1 except Exception as e: print(f"Error deleting collection {collection_name}: {e}") self.current_collection = None self.embedding_model = None print(f"✅ Cleared {count} collections") return count def delete_collection(self, collection_name: str) -> bool: """Delete a specific collection. Args: collection_name: Name of the collection to delete Returns: True if deleted successfully, False otherwise """ try: self.client.delete_collection(collection_name) if self.current_collection and self.current_collection.name == collection_name: self.current_collection = None self.embedding_model = None print(f"✅ Deleted collection: {collection_name}") return True except Exception as e: print(f"❌ Error deleting collection: {e}") return False def add_documents( self, documents: List[str], metadatas: Optional[List[Dict]] = None, ids: Optional[List[str]] = None, batch_size: int = 100 ): """Add documents to the current collection. Args: documents: List of document texts metadatas: List of metadata dictionaries ids: List of document IDs batch_size: Batch size for processing """ if not self.current_collection: raise ValueError("No collection selected. Create or get a collection first.") if not self.embedding_model: raise ValueError("No embedding model loaded.") # Generate IDs if not provided if ids is None: ids = [str(uuid.uuid4()) for _ in documents] # Generate default metadata if not provided if metadatas is None: metadatas = [{"index": i} for i in range(len(documents))] # Process in batches total_docs = len(documents) print(f"Adding {total_docs} documents to collection...") for i in range(0, total_docs, batch_size): batch_docs = documents[i:i + batch_size] batch_ids = ids[i:i + batch_size] batch_metadatas = metadatas[i:i + batch_size] # Generate embeddings embeddings = self.embedding_model.embed_documents(batch_docs) # Add to collection self.current_collection.add( documents=batch_docs, embeddings=embeddings.tolist(), metadatas=batch_metadatas, ids=batch_ids ) print(f"Added batch {i//batch_size + 1}/{(total_docs-1)//batch_size + 1}") print(f"Successfully added {total_docs} documents") def load_dataset_into_collection( self, collection_name: str, embedding_model_name: str, chunking_strategy: str, dataset_data: List[Dict], chunk_size: int = 512, overlap: int = 50 ): """Load a dataset into a new collection with chunking. Args: collection_name: Name for the new collection embedding_model_name: Embedding model to use chunking_strategy: Chunking strategy to use dataset_data: List of dataset samples chunk_size: Size of chunks overlap: Overlap between chunks """ # Create collection self.create_collection( collection_name, embedding_model_name, metadata={ "chunking_strategy": chunking_strategy, "chunk_size": chunk_size, "overlap": overlap } ) # Get chunker chunker = ChunkingFactory.create_chunker(chunking_strategy) # Process documents all_chunks = [] all_metadatas = [] print(f"Processing {len(dataset_data)} documents with {chunking_strategy} chunking...") for idx, sample in enumerate(dataset_data): # Use 'documents' list if available, otherwise fall back to 'context' documents = sample.get("documents", []) # If documents is empty, use context as fallback if not documents: context = sample.get("context", "") if context: documents = [context] if not documents: continue # Process each document separately for better granularity for doc_idx, document in enumerate(documents): if not document or not str(document).strip(): continue # Chunk each document chunks = chunker.chunk_text(str(document), chunk_size, overlap) # Create metadata for each chunk for chunk_idx, chunk in enumerate(chunks): all_chunks.append(chunk) all_metadatas.append({ "doc_id": idx, "doc_idx": doc_idx, # Track which document within the sample "chunk_id": chunk_idx, "question": sample.get("question", ""), "answer": sample.get("answer", ""), "dataset": sample.get("dataset", ""), "total_docs": len(documents) }) # Add all chunks to collection self.add_documents(all_chunks, all_metadatas) print(f"Loaded {len(all_chunks)} chunks from {len(dataset_data)} samples") def query( self, query_text: str, n_results: int = 5, filter_metadata: Optional[Dict] = None ) -> Dict: """Query the collection. Args: query_text: Query text n_results: Number of results to return filter_metadata: Metadata filter Returns: Query results """ if not self.current_collection: raise ValueError("No collection selected.") if not self.embedding_model: raise ValueError("No embedding model loaded.") # Generate query embedding query_embedding = self.embedding_model.embed_query(query_text) # Query collection with retry logic try: results = self.current_collection.query( query_embeddings=[query_embedding.tolist()], n_results=n_results, where=filter_metadata ) except Exception as e: if "default_tenant" in str(e): print("Warning: Lost connection to ChromaDB, reconnecting...") self.reconnect() # Try again after reconnecting results = self.current_collection.query( query_embeddings=[query_embedding.tolist()], n_results=n_results, where=filter_metadata ) else: raise return results def get_retrieved_documents( self, query_text: str, n_results: int = 5 ) -> List[Dict]: """Get retrieved documents with metadata. Args: query_text: Query text n_results: Number of results Returns: List of retrieved documents with metadata """ results = self.query(query_text, n_results) retrieved_docs = [] for i in range(len(results['documents'][0])): retrieved_docs.append({ "document": results['documents'][0][i], "metadata": results['metadatas'][0][i], "distance": results['distances'][0][i] if 'distances' in results else None }) return retrieved_docs def delete_collection(self, collection_name: str): """Delete a collection. Args: collection_name: Name of collection to delete """ try: self.client.delete_collection(collection_name) print(f"Deleted collection: {collection_name}") except Exception as e: print(f"Error deleting collection: {str(e)}") def get_collection_stats(self, collection_name: Optional[str] = None) -> Dict: """Get statistics for a collection. Args: collection_name: Name of collection (uses current if None) Returns: Dictionary with collection statistics """ if collection_name: collection = self.client.get_collection(collection_name) elif self.current_collection: collection = self.current_collection else: raise ValueError("No collection specified or selected") count = collection.count() metadata = collection.metadata return { "name": collection.name, "count": count, "metadata": metadata }