"""ChromaDB and Qdrant integration for vector storage and retrieval.""" from typing import List, Dict, Optional, Tuple import chromadb from chromadb.config import Settings import uuid import os from embedding_models import EmbeddingFactory, EmbeddingModel from chunking_strategies import ChunkingFactory import json # Qdrant imports (optional - for cloud deployment) try: from qdrant_client import QdrantClient from qdrant_client.http import models as qdrant_models from qdrant_client.http.models import Distance, VectorParams, PointStruct QDRANT_AVAILABLE = True except ImportError: QDRANT_AVAILABLE = False print("Warning: qdrant-client not installed. Qdrant support disabled.") class ChromaDBManager: """Manager for ChromaDB operations.""" def __init__(self, persist_directory: str = "./chroma_db"): """Initialize ChromaDB manager. Args: persist_directory: Directory to persist ChromaDB data """ self.persist_directory = persist_directory os.makedirs(persist_directory, exist_ok=True) # Initialize ChromaDB client with is_persistent=True to use persistent storage try: self.client = chromadb.PersistentClient( path=persist_directory, settings=Settings( anonymized_telemetry=False, allow_reset=True # Allow reset if needed ) ) except Exception as e: print(f"Warning: Could not create persistent client: {e}") print("Falling back to regular client...") self.client = chromadb.Client(Settings( persist_directory=persist_directory, anonymized_telemetry=False, allow_reset=True )) self.embedding_model = None self.current_collection = None # Track evaluation-related metadata for reproducibility self.chunking_strategy = None self.chunk_size = None self.chunk_overlap = None def reconnect(self): """Reconnect to ChromaDB in case of connection loss.""" try: self.client = chromadb.PersistentClient( path=self.persist_directory, settings=Settings( anonymized_telemetry=False, allow_reset=True ) ) print("✅ Reconnected to ChromaDB") except Exception as e: print(f"Error reconnecting: {e}") def create_collection( self, collection_name: str, embedding_model_name: str, metadata: Optional[Dict] = None ) -> chromadb.Collection: """Create a new collection. Args: collection_name: Name of the collection embedding_model_name: Name of the embedding model metadata: Additional metadata for the collection Returns: ChromaDB collection """ # Delete if exists try: self.client.delete_collection(collection_name) except: pass # Create embedding model self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name) self.embedding_model.load_model() # Create collection with metadata collection_metadata = { "embedding_model": embedding_model_name, "hnsw:space": "cosine" } if metadata: collection_metadata.update(metadata) self.current_collection = self.client.create_collection( name=collection_name, metadata=collection_metadata ) print(f"Created collection: {collection_name}") return self.current_collection def get_collection(self, collection_name: str) -> chromadb.Collection: """Get an existing collection. Args: collection_name: Name of the collection Returns: ChromaDB collection """ self.current_collection = self.client.get_collection(collection_name) # Load embedding model from metadata metadata = self.current_collection.metadata if "embedding_model" in metadata: self.embedding_model = EmbeddingFactory.create_embedding_model( metadata["embedding_model"] ) self.embedding_model.load_model() # Restore chunking metadata for evaluation reproducibility if "chunking_strategy" in metadata: self.chunking_strategy = metadata["chunking_strategy"] if "chunk_size" in metadata: self.chunk_size = metadata["chunk_size"] if "overlap" in metadata: self.chunk_overlap = metadata["overlap"] return self.current_collection def list_collections(self) -> List[str]: """List all collections. Returns: List of collection names """ collections = self.client.list_collections() return [col.name for col in collections] def clear_all_collections(self) -> int: """Delete all collections from the database. Returns: Number of collections deleted """ collections = self.list_collections() count = 0 for collection_name in collections: try: self.client.delete_collection(collection_name) print(f"Deleted collection: {collection_name}") count += 1 except Exception as e: print(f"Error deleting collection {collection_name}: {e}") self.current_collection = None self.embedding_model = None print(f"✅ Cleared {count} collections") return count def delete_collection(self, collection_name: str) -> bool: """Delete a specific collection. Args: collection_name: Name of the collection to delete Returns: True if deleted successfully, False otherwise """ try: self.client.delete_collection(collection_name) if self.current_collection and self.current_collection.name == collection_name: self.current_collection = None self.embedding_model = None print(f"✅ Deleted collection: {collection_name}") return True except Exception as e: print(f"❌ Error deleting collection: {e}") return False def add_documents( self, documents: List[str], metadatas: Optional[List[Dict]] = None, ids: Optional[List[str]] = None, batch_size: int = 100 ): """Add documents to the current collection. Args: documents: List of document texts metadatas: List of metadata dictionaries ids: List of document IDs batch_size: Batch size for processing """ if not self.current_collection: raise ValueError("No collection selected. Create or get a collection first.") if not self.embedding_model: raise ValueError("No embedding model loaded.") # Generate IDs if not provided if ids is None: ids = [str(uuid.uuid4()) for _ in documents] # Generate default metadata if not provided if metadatas is None: metadatas = [{"index": i} for i in range(len(documents))] # Process in batches total_docs = len(documents) print(f"Adding {total_docs} documents to collection...") for i in range(0, total_docs, batch_size): batch_docs = documents[i:i + batch_size] batch_ids = ids[i:i + batch_size] batch_metadatas = metadatas[i:i + batch_size] # Generate embeddings embeddings = self.embedding_model.embed_documents(batch_docs) # Add to collection self.current_collection.add( documents=batch_docs, embeddings=embeddings.tolist(), metadatas=batch_metadatas, ids=batch_ids ) print(f"Added batch {i//batch_size + 1}/{(total_docs-1)//batch_size + 1}") print(f"Successfully added {total_docs} documents") def load_dataset_into_collection( self, collection_name: str, embedding_model_name: str, chunking_strategy: str, dataset_data: List[Dict], chunk_size: int = 512, overlap: int = 50 ): """Load a dataset into a new collection with chunking. Args: collection_name: Name for the new collection embedding_model_name: Embedding model to use chunking_strategy: Chunking strategy to use dataset_data: List of dataset samples chunk_size: Size of chunks overlap: Overlap between chunks """ # Store metadata for later evaluation reference self.chunking_strategy = chunking_strategy self.chunk_size = chunk_size self.chunk_overlap = overlap # Create collection self.create_collection( collection_name, embedding_model_name, metadata={ "chunking_strategy": chunking_strategy, "chunk_size": chunk_size, "overlap": overlap } ) # Get chunker chunker = ChunkingFactory.create_chunker(chunking_strategy) # Process documents all_chunks = [] all_metadatas = [] print(f"Processing {len(dataset_data)} documents with {chunking_strategy} chunking...") for idx, sample in enumerate(dataset_data): # Use 'documents' list if available, otherwise fall back to 'context' documents = sample.get("documents", []) # If documents is empty, use context as fallback if not documents: context = sample.get("context", "") if context: documents = [context] if not documents: continue # Process each document separately for better granularity for doc_idx, document in enumerate(documents): if not document or not str(document).strip(): continue # Chunk each document chunks = chunker.chunk_text(str(document), chunk_size, overlap) # Create metadata for each chunk for chunk_idx, chunk in enumerate(chunks): all_chunks.append(chunk) all_metadatas.append({ "doc_id": idx, "doc_idx": doc_idx, # Track which document within the sample "chunk_id": chunk_idx, "question": sample.get("question", ""), "answer": sample.get("answer", ""), "dataset": sample.get("dataset", ""), "total_docs": len(documents) }) # Add all chunks to collection self.add_documents(all_chunks, all_metadatas) print(f"Loaded {len(all_chunks)} chunks from {len(dataset_data)} samples") def query( self, query_text: str, n_results: int = 5, filter_metadata: Optional[Dict] = None ) -> Dict: """Query the collection. Args: query_text: Query text n_results: Number of results to return filter_metadata: Metadata filter Returns: Query results """ if not self.current_collection: raise ValueError("No collection selected.") if not self.embedding_model: raise ValueError("No embedding model loaded.") # Generate query embedding query_embedding = self.embedding_model.embed_query(query_text) # Query collection with retry logic try: results = self.current_collection.query( query_embeddings=[query_embedding.tolist()], n_results=n_results, where=filter_metadata ) except Exception as e: if "default_tenant" in str(e): print("Warning: Lost connection to ChromaDB, reconnecting...") self.reconnect() # Try again after reconnecting results = self.current_collection.query( query_embeddings=[query_embedding.tolist()], n_results=n_results, where=filter_metadata ) else: raise return results def get_retrieved_documents( self, query_text: str, n_results: int = 5 ) -> List[Dict]: """Get retrieved documents with metadata. Args: query_text: Query text n_results: Number of results Returns: List of retrieved documents with metadata """ results = self.query(query_text, n_results) retrieved_docs = [] for i in range(len(results['documents'][0])): retrieved_docs.append({ "document": results['documents'][0][i], "metadata": results['metadatas'][0][i], "distance": results['distances'][0][i] if 'distances' in results else None }) return retrieved_docs def delete_collection(self, collection_name: str): """Delete a collection. Args: collection_name: Name of collection to delete """ try: self.client.delete_collection(collection_name) print(f"Deleted collection: {collection_name}") except Exception as e: print(f"Error deleting collection: {str(e)}") def get_collection_stats(self, collection_name: Optional[str] = None) -> Dict: """Get statistics for a collection. Args: collection_name: Name of collection (uses current if None) Returns: Dictionary with collection statistics """ if collection_name: collection = self.client.get_collection(collection_name) elif self.current_collection: collection = self.current_collection else: raise ValueError("No collection specified or selected") count = collection.count() metadata = collection.metadata return { "name": collection.name, "count": count, "metadata": metadata } class QdrantManager: """Manager for Qdrant Cloud operations - persistent storage for HuggingFace Spaces.""" def __init__(self, url: str = None, api_key: str = None): """Initialize Qdrant client. Args: url: Qdrant Cloud URL (e.g., https://xxx.qdrant.io) api_key: Qdrant API key """ if not QDRANT_AVAILABLE: raise ImportError("qdrant-client is not installed. Run: pip install qdrant-client") self.url = url or os.environ.get("QDRANT_URL", "") self.api_key = api_key or os.environ.get("QDRANT_API_KEY", "") if not self.url or not self.api_key: raise ValueError("QDRANT_URL and QDRANT_API_KEY are required") self.client = QdrantClient( url=self.url, api_key=self.api_key, timeout=60 ) self.embedding_model = None self.current_collection = None self.vector_size = None self.chunking_strategy = None self.chunk_size = None self.chunk_overlap = None print(f"[QDRANT] Connected to Qdrant Cloud at {self.url}") def create_collection( self, collection_name: str, embedding_model_name: str, metadata: Optional[Dict] = None ): """Create a new collection in Qdrant. Args: collection_name: Name of the collection embedding_model_name: Name of the embedding model metadata: Additional metadata for the collection """ # Create embedding model to get vector size self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name) self.embedding_model.load_model() # Get vector size from a sample embedding sample_embedding = self.embedding_model.embed_query("test") self.vector_size = len(sample_embedding) # Delete if exists try: self.client.delete_collection(collection_name) print(f"[QDRANT] Deleted existing collection: {collection_name}") except: pass # Create collection self.client.create_collection( collection_name=collection_name, vectors_config=VectorParams( size=self.vector_size, distance=Distance.COSINE ) ) self.current_collection = collection_name print(f"[QDRANT] Created collection: {collection_name} (vector_size={self.vector_size})") return self.current_collection def get_collection(self, collection_name: str): """Get an existing collection. Args: collection_name: Name of the collection """ # Verify collection exists collections = self.list_collections() if collection_name not in collections: raise ValueError(f"Collection '{collection_name}' not found") self.current_collection = collection_name # Get collection info to determine embedding model info = self.client.get_collection(collection_name) self.vector_size = info.config.params.vectors.size # Try to load embedding model from first document's metadata embedding_model_name = None try: # Scroll to get first point points, _ = self.client.scroll( collection_name=collection_name, limit=1, with_payload=True ) if points and len(points) > 0: payload = points[0].payload embedding_model_name = payload.get("embedding_model") if "chunking_strategy" in payload: self.chunking_strategy = payload["chunking_strategy"] except Exception as e: print(f"[QDRANT] Warning: Could not retrieve metadata: {e}") # If not found in metadata, try to infer from collection name if not embedding_model_name: # Collection name format: dataset_strategy_modelname # Try common embedding models known_models = [ "all-mpnet-base-v2", "all-MiniLM-L6-v2", "paraphrase-MiniLM-L6-v2", "multi-qa-MiniLM-L6-cos-v1" ] for model in known_models: if model.lower().replace("-", "") in collection_name.lower().replace("-", "").replace("_", ""): embedding_model_name = f"sentence-transformers/{model}" break # Default fallback if not embedding_model_name: embedding_model_name = "sentence-transformers/all-mpnet-base-v2" print(f"[QDRANT] Warning: Could not determine embedding model, using default: {embedding_model_name}") # Load the embedding model if embedding_model_name: self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name) self.embedding_model.load_model() print(f"[QDRANT] Loaded embedding model: {embedding_model_name}") print(f"[QDRANT] Loaded collection: {collection_name}") return self.current_collection def list_collections(self) -> List[str]: """List all collections. Returns: List of collection names """ collections = self.client.get_collections() return [col.name for col in collections.collections] def add_documents( self, documents: List[str], metadatas: Optional[List[Dict]] = None, ids: Optional[List[str]] = None, batch_size: int = 100 ): """Add documents to the current collection. Args: documents: List of document texts metadatas: List of metadata dictionaries ids: List of document IDs batch_size: Batch size for processing """ if not self.current_collection: raise ValueError("No collection selected. Create or get a collection first.") if not self.embedding_model: raise ValueError("No embedding model loaded.") # Generate IDs if not provided if ids is None: ids = [str(uuid.uuid4()) for _ in documents] # Generate default metadata if not provided if metadatas is None: metadatas = [{"index": i} for i in range(len(documents))] # Process in batches total_docs = len(documents) print(f"[QDRANT] Adding {total_docs} documents to collection...") for i in range(0, total_docs, batch_size): batch_docs = documents[i:i + batch_size] batch_ids = ids[i:i + batch_size] batch_metadatas = metadatas[i:i + batch_size] # Generate embeddings embeddings = self.embedding_model.embed_documents(batch_docs) # Create points points = [] for j, (doc, embedding, meta, doc_id) in enumerate(zip(batch_docs, embeddings, batch_metadatas, batch_ids)): # Store document text in payload payload = {**meta, "text": doc} points.append(PointStruct( id=i + j, # Use integer ID vector=embedding.tolist(), payload=payload )) # Upsert to collection self.client.upsert( collection_name=self.current_collection, points=points ) print(f"[QDRANT] Added batch {i//batch_size + 1}/{(total_docs-1)//batch_size + 1}") print(f"[QDRANT] Successfully added {total_docs} documents") def load_dataset_into_collection( self, collection_name: str, embedding_model_name: str, chunking_strategy: str, dataset_data: List[Dict], chunk_size: int = 512, overlap: int = 50 ): """Load a dataset into a new collection with chunking. Args: collection_name: Name for the new collection embedding_model_name: Embedding model to use chunking_strategy: Chunking strategy to use dataset_data: List of dataset samples chunk_size: Size of chunks overlap: Overlap between chunks """ self.chunking_strategy = chunking_strategy self.chunk_size = chunk_size self.chunk_overlap = overlap # Create collection self.create_collection(collection_name, embedding_model_name) # Get chunker chunker = ChunkingFactory.create_chunker(chunking_strategy) # Process documents all_chunks = [] all_metadatas = [] print(f"[QDRANT] Processing {len(dataset_data)} documents with {chunking_strategy} chunking...") for idx, sample in enumerate(dataset_data): documents = sample.get("documents", []) if not documents: context = sample.get("context", "") if context: documents = [context] if not documents: continue for doc_idx, document in enumerate(documents): if not document or not str(document).strip(): continue chunks = chunker.chunk_text(str(document), chunk_size, overlap) for chunk_idx, chunk in enumerate(chunks): all_chunks.append(chunk) all_metadatas.append({ "doc_id": idx, "doc_idx": doc_idx, "chunk_id": chunk_idx, "question": sample.get("question", ""), "answer": sample.get("answer", ""), "dataset": sample.get("dataset", ""), "total_docs": len(documents), "embedding_model": embedding_model_name, "chunking_strategy": chunking_strategy }) # Add all chunks to collection self.add_documents(all_chunks, all_metadatas) print(f"[QDRANT] Loaded {len(all_chunks)} chunks from {len(dataset_data)} samples") def query( self, query_text: str, n_results: int = 5, filter_metadata: Optional[Dict] = None ) -> Dict: """Query the collection. Args: query_text: Query text n_results: Number of results to return filter_metadata: Metadata filter Returns: Query results in ChromaDB-compatible format """ if not self.current_collection: raise ValueError("No collection selected.") if not self.embedding_model: raise ValueError("No embedding model loaded.") # Generate query embedding query_embedding = self.embedding_model.embed_query(query_text) # Query Qdrant using query_points (newer API) or search (older API) try: # Try newer API first (qdrant-client >= 1.7) from qdrant_client.http.models import QueryRequest results = self.client.query_points( collection_name=self.current_collection, query=query_embedding.tolist(), limit=n_results, with_payload=True ).points except (AttributeError, ImportError): # Fallback to older API results = self.client.search( collection_name=self.current_collection, query_vector=query_embedding.tolist(), limit=n_results ) # Convert to ChromaDB-compatible format documents = [[r.payload.get("text", "") for r in results]] metadatas = [[{k: v for k, v in r.payload.items() if k != "text"} for r in results]] distances = [[1 - r.score for r in results]] # Convert similarity to distance return { "documents": documents, "metadatas": metadatas, "distances": distances } def get_retrieved_documents( self, query_text: str, n_results: int = 5 ) -> List[Dict]: """Get retrieved documents with metadata. Args: query_text: Query text n_results: Number of results Returns: List of retrieved documents with metadata """ results = self.query(query_text, n_results) retrieved_docs = [] for i in range(len(results['documents'][0])): retrieved_docs.append({ "document": results['documents'][0][i], "metadata": results['metadatas'][0][i], "distance": results['distances'][0][i] if 'distances' in results else None }) return retrieved_docs def delete_collection(self, collection_name: str) -> bool: """Delete a specific collection. Args: collection_name: Name of the collection to delete Returns: True if deleted successfully, False otherwise """ try: self.client.delete_collection(collection_name) if self.current_collection == collection_name: self.current_collection = None self.embedding_model = None print(f"[QDRANT] Deleted collection: {collection_name}") return True except Exception as e: print(f"[QDRANT] Error deleting collection: {e}") return False def get_collection_stats(self, collection_name: Optional[str] = None) -> Dict: """Get statistics for a collection. Args: collection_name: Name of collection (uses current if None) Returns: Dictionary with collection statistics """ coll_name = collection_name or self.current_collection if not coll_name: raise ValueError("No collection specified or selected") info = self.client.get_collection(coll_name) return { "name": coll_name, "count": info.points_count, "vector_size": info.config.params.vectors.size, "status": info.status } def create_vector_store(provider: str = "chroma", **kwargs): """Factory function to create vector store manager. Args: provider: "chroma" or "qdrant" **kwargs: Provider-specific arguments Returns: ChromaDBManager or QdrantManager instance """ if provider == "qdrant": if not QDRANT_AVAILABLE: raise ImportError("qdrant-client not installed. Run: pip install qdrant-client") return QdrantManager( url=kwargs.get("url") or os.environ.get("QDRANT_URL"), api_key=kwargs.get("api_key") or os.environ.get("QDRANT_API_KEY") ) else: return ChromaDBManager( persist_directory=kwargs.get("persist_directory", "./chroma_db") )