Spaces:
Sleeping
Sleeping
| """ChromaDB and Qdrant integration for vector storage and retrieval.""" | |
| from typing import List, Dict, Optional, Tuple | |
| import chromadb | |
| from chromadb.config import Settings | |
| import uuid | |
| import os | |
| from embedding_models import EmbeddingFactory, EmbeddingModel | |
| from chunking_strategies import ChunkingFactory | |
| import json | |
| # Qdrant imports (optional - for cloud deployment) | |
| try: | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.http import models as qdrant_models | |
| from qdrant_client.http.models import Distance, VectorParams, PointStruct | |
| QDRANT_AVAILABLE = True | |
| except ImportError: | |
| QDRANT_AVAILABLE = False | |
| print("Warning: qdrant-client not installed. Qdrant support disabled.") | |
class ChromaDBManager:
    """Manager for ChromaDB operations.

    Bundles a ChromaDB client with the embedding model used to encode
    documents and queries. One collection is "current" at a time; the add and
    query methods operate on that collection.
    """

    def __init__(self, persist_directory: str = "./chroma_db"):
        """Initialize ChromaDB manager.

        Args:
            persist_directory: Directory to persist ChromaDB data
        """
        self.persist_directory = persist_directory
        os.makedirs(persist_directory, exist_ok=True)

        # Prefer the persistent client; fall back to the generic client when
        # the persistent one cannot be constructed.
        try:
            self.client = self._new_persistent_client()
        except Exception as e:
            print(f"Warning: Could not create persistent client: {e}")
            print("Falling back to regular client...")
            self.client = chromadb.Client(Settings(
                persist_directory=persist_directory,
                anonymized_telemetry=False,
                allow_reset=True
            ))

        self.embedding_model = None
        self.current_collection = None

        # Track evaluation-related metadata for reproducibility
        self.chunking_strategy = None
        self.chunk_size = None
        self.chunk_overlap = None

    def _new_persistent_client(self):
        """Build a persistent client rooted at ``self.persist_directory``."""
        return chromadb.PersistentClient(
            path=self.persist_directory,
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True  # Allow reset if needed
            )
        )

    def reconnect(self):
        """Reconnect to ChromaDB in case of connection loss.

        The current collection (if any) is looked up again through the new
        client, because the previous handle still points at the old client.
        Errors are reported on stdout rather than raised.
        """
        try:
            self.client = self._new_persistent_client()
            if self.current_collection is not None:
                self.current_collection = self.client.get_collection(
                    self.current_collection.name
                )
            print("✅ Reconnected to ChromaDB")
        except Exception as e:
            print(f"Error reconnecting: {e}")

    def create_collection(
        self,
        collection_name: str,
        embedding_model_name: str,
        metadata: Optional[Dict] = None
    ) -> "chromadb.Collection":
        """Create a new collection, replacing any existing one of that name.

        Args:
            collection_name: Name of the collection
            embedding_model_name: Name of the embedding model
            metadata: Additional metadata for the collection

        Returns:
            ChromaDB collection
        """
        # Best-effort delete: the collection usually does not exist yet, and
        # the exception type raised for that differs between Chroma versions.
        try:
            self.client.delete_collection(collection_name)
        except Exception:
            pass

        # Create embedding model
        self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name)
        self.embedding_model.load_model()

        # Create collection with metadata; caller metadata may override defaults
        collection_metadata = {
            "embedding_model": embedding_model_name,
            "hnsw:space": "cosine"
        }
        if metadata:
            collection_metadata.update(metadata)

        self.current_collection = self.client.create_collection(
            name=collection_name,
            metadata=collection_metadata
        )
        print(f"Created collection: {collection_name}")
        return self.current_collection

    def get_collection(self, collection_name: str) -> "chromadb.Collection":
        """Get an existing collection and make it the current one.

        The embedding model and chunking settings recorded in the collection
        metadata are restored when present; values that are absent leave the
        corresponding attribute unchanged.

        Args:
            collection_name: Name of the collection

        Returns:
            ChromaDB collection
        """
        self.current_collection = self.client.get_collection(collection_name)

        # A collection created without metadata reports None here.
        metadata = self.current_collection.metadata or {}

        # Load embedding model from metadata
        if "embedding_model" in metadata:
            self.embedding_model = EmbeddingFactory.create_embedding_model(
                metadata["embedding_model"]
            )
            self.embedding_model.load_model()

        # Restore chunking metadata for evaluation reproducibility
        restorable = (
            ("chunking_strategy", "chunking_strategy"),
            ("chunk_size", "chunk_size"),
            ("overlap", "chunk_overlap"),
        )
        for meta_key, attr_name in restorable:
            if meta_key in metadata:
                setattr(self, attr_name, metadata[meta_key])

        return self.current_collection

    def list_collections(self) -> List[str]:
        """List all collections.

        Returns:
            List of collection names
        """
        # Depending on the Chroma version the client yields Collection
        # objects or plain name strings; accept both.
        return [getattr(col, "name", col) for col in self.client.list_collections()]

    def clear_all_collections(self) -> int:
        """Delete all collections from the database.

        Returns:
            Number of collections deleted
        """
        count = 0
        for collection_name in self.list_collections():
            try:
                self.client.delete_collection(collection_name)
                print(f"Deleted collection: {collection_name}")
                count += 1
            except Exception as e:
                print(f"Error deleting collection {collection_name}: {e}")

        self.current_collection = None
        self.embedding_model = None
        print(f"✅ Cleared {count} collections")
        return count

    def delete_collection(self, collection_name: str) -> bool:
        """Delete a specific collection.

        If the deleted collection is the current one, the current collection
        and its embedding model are cleared.

        Args:
            collection_name: Name of the collection to delete

        Returns:
            True if deleted successfully, False otherwise
        """
        try:
            self.client.delete_collection(collection_name)
        except Exception as e:
            print(f"❌ Error deleting collection: {e}")
            return False

        current = self.current_collection
        if current is not None and current.name == collection_name:
            self.current_collection = None
            self.embedding_model = None
        print(f"✅ Deleted collection: {collection_name}")
        return True

    def add_documents(
        self,
        documents: List[str],
        metadatas: Optional[List[Dict]] = None,
        ids: Optional[List[str]] = None,
        batch_size: int = 100
    ):
        """Add documents to the current collection.

        Args:
            documents: List of document texts
            metadatas: List of metadata dictionaries (defaults to
                ``{"index": i}`` per document)
            ids: List of document IDs (defaults to random UUID4 strings)
            batch_size: Batch size for processing

        Raises:
            ValueError: If no collection or embedding model is loaded, or if
                ``metadatas``/``ids`` differ in length from ``documents``.
        """
        if self.current_collection is None:
            raise ValueError("No collection selected. Create or get a collection first.")
        if self.embedding_model is None:
            raise ValueError("No embedding model loaded.")

        total_docs = len(documents)

        # Generate IDs if not provided
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in documents]
        # Generate default metadata if not provided
        if metadatas is None:
            metadatas = [{"index": i} for i in range(total_docs)]

        # Batches are cut by slicing, so a length mismatch would otherwise
        # silently drop the surplus entries.
        if len(ids) != total_docs or len(metadatas) != total_docs:
            raise ValueError(
                f"Length mismatch: {total_docs} documents, "
                f"{len(metadatas)} metadatas, {len(ids)} ids"
            )

        print(f"Adding {total_docs} documents to collection...")
        total_batches = (total_docs - 1) // batch_size + 1
        for start in range(0, total_docs, batch_size):
            stop = start + batch_size
            batch_docs = documents[start:stop]

            # Generate embeddings
            embeddings = self.embedding_model.embed_documents(batch_docs)

            # Add to collection
            self.current_collection.add(
                documents=batch_docs,
                embeddings=embeddings.tolist(),
                metadatas=metadatas[start:stop],
                ids=ids[start:stop]
            )
            print(f"Added batch {start // batch_size + 1}/{total_batches}")

        print(f"Successfully added {total_docs} documents")

    @staticmethod
    def _build_chunks(chunker, dataset_data, chunk_size, overlap):
        """Chunk every document of every sample; return (chunks, metadatas)."""
        all_chunks = []
        all_metadatas = []
        for idx, sample in enumerate(dataset_data):
            # Use 'documents' if available, otherwise fall back to 'context'
            documents = sample.get("documents", [])
            if isinstance(documents, str):
                # A bare string would otherwise be iterated per character.
                documents = [documents]
            if not documents:
                context = sample.get("context", "")
                if context:
                    documents = [context]
            if not documents:
                continue

            # Process each document separately for better granularity
            for doc_idx, document in enumerate(documents):
                if not document or not str(document).strip():
                    continue
                chunks = chunker.chunk_text(str(document), chunk_size, overlap)
                for chunk_idx, chunk in enumerate(chunks):
                    all_chunks.append(chunk)
                    all_metadatas.append({
                        "doc_id": idx,
                        "doc_idx": doc_idx,  # Which document within the sample
                        "chunk_id": chunk_idx,
                        "question": sample.get("question", ""),
                        "answer": sample.get("answer", ""),
                        "dataset": sample.get("dataset", ""),
                        "total_docs": len(documents)
                    })
        return all_chunks, all_metadatas

    def load_dataset_into_collection(
        self,
        collection_name: str,
        embedding_model_name: str,
        chunking_strategy: str,
        dataset_data: List[Dict],
        chunk_size: int = 512,
        overlap: int = 50
    ):
        """Load a dataset into a new collection with chunking.

        Args:
            collection_name: Name for the new collection
            embedding_model_name: Embedding model to use
            chunking_strategy: Chunking strategy to use
            dataset_data: List of dataset samples; each sample supplies its
                texts under "documents", or a single text under "context"
            chunk_size: Size of chunks
            overlap: Overlap between chunks
        """
        # Store metadata for later evaluation reference
        self.chunking_strategy = chunking_strategy
        self.chunk_size = chunk_size
        self.chunk_overlap = overlap

        # Create collection; the chunking settings are stored in its metadata
        self.create_collection(
            collection_name,
            embedding_model_name,
            metadata={
                "chunking_strategy": chunking_strategy,
                "chunk_size": chunk_size,
                "overlap": overlap
            }
        )

        chunker = ChunkingFactory.create_chunker(chunking_strategy)
        print(f"Processing {len(dataset_data)} documents with {chunking_strategy} chunking...")
        all_chunks, all_metadatas = self._build_chunks(
            chunker, dataset_data, chunk_size, overlap
        )

        # Add all chunks to collection
        self.add_documents(all_chunks, all_metadatas)
        print(f"Loaded {len(all_chunks)} chunks from {len(dataset_data)} samples")

    def query(
        self,
        query_text: str,
        n_results: int = 5,
        filter_metadata: Optional[Dict] = None
    ) -> Dict:
        """Query the current collection.

        Args:
            query_text: Query text
            n_results: Number of results to return
            filter_metadata: Metadata filter passed to Chroma as ``where``

        Returns:
            Query results

        Raises:
            ValueError: If no collection or embedding model is loaded.
        """
        if self.current_collection is None:
            raise ValueError("No collection selected.")
        if self.embedding_model is None:
            raise ValueError("No embedding model loaded.")

        # Generate query embedding
        query_embedding = self.embedding_model.embed_query(query_text)
        query_kwargs = {
            "query_embeddings": [query_embedding.tolist()],
            "n_results": n_results,
            # An empty filter is normalized to "no filter".
            "where": filter_metadata or None,
        }

        try:
            return self.current_collection.query(**query_kwargs)
        except Exception as e:
            # Only this error text is treated as a lost connection; anything
            # else propagates unchanged.
            if "default_tenant" not in str(e):
                raise
            print("Warning: Lost connection to ChromaDB, reconnecting...")
            self.reconnect()
            # reconnect() rebinds current_collection to the new client.
            return self.current_collection.query(**query_kwargs)

    def get_retrieved_documents(
        self,
        query_text: str,
        n_results: int = 5
    ) -> List[Dict]:
        """Get retrieved documents with metadata.

        Args:
            query_text: Query text
            n_results: Number of results

        Returns:
            List of dicts with "document", "metadata" and "distance" keys;
            "distance" is None when the results carry no distances.
        """
        results = self.query(query_text, n_results)
        documents = results["documents"][0]
        metadatas = results["metadatas"][0]
        distance_rows = results.get("distances")
        distances = distance_rows[0] if distance_rows else [None] * len(documents)
        return [
            {"document": doc, "metadata": meta, "distance": dist}
            for doc, meta, dist in zip(documents, metadatas, distances)
        ]

    def get_collection_stats(self, collection_name: Optional[str] = None) -> Dict:
        """Get statistics for a collection.

        Args:
            collection_name: Name of collection (uses current if None)

        Returns:
            Dictionary with the collection's name, item count and metadata

        Raises:
            ValueError: If no name is given and no collection is selected.
        """
        if collection_name:
            collection = self.client.get_collection(collection_name)
        elif self.current_collection is not None:
            collection = self.current_collection
        else:
            raise ValueError("No collection specified or selected")

        return {
            "name": collection.name,
            "count": collection.count(),
            "metadata": collection.metadata
        }
class QdrantManager:
    """Manager for Qdrant Cloud operations - persistent storage for HuggingFace Spaces.

    Exposes the same method names as ChromaDBManager. The "current
    collection" is tracked by name, and document text is stored in each
    point's payload under the "text" key.
    """

    # Model names tried when a collection's payload does not record one.
    _KNOWN_MODELS = (
        "all-mpnet-base-v2",
        "all-MiniLM-L6-v2",
        "paraphrase-MiniLM-L6-v2",
        "multi-qa-MiniLM-L6-cos-v1",
    )
    _DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"

    def __init__(self, url: Optional[str] = None, api_key: Optional[str] = None):
        """Initialize Qdrant client.

        Args:
            url: Qdrant Cloud URL (e.g., https://xxx.qdrant.io); falls back
                to the QDRANT_URL environment variable
            api_key: Qdrant API key; falls back to QDRANT_API_KEY

        Raises:
            ImportError: If qdrant-client is not installed.
            ValueError: If the URL or API key cannot be resolved.
        """
        if not QDRANT_AVAILABLE:
            raise ImportError("qdrant-client is not installed. Run: pip install qdrant-client")

        self.url = url or os.environ.get("QDRANT_URL", "")
        self.api_key = api_key or os.environ.get("QDRANT_API_KEY", "")
        if not self.url or not self.api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY are required")

        self.client = QdrantClient(
            url=self.url,
            api_key=self.api_key,
            timeout=60
        )
        self.embedding_model = None
        self.current_collection = None
        self.vector_size = None
        self.chunking_strategy = None
        self.chunk_size = None
        self.chunk_overlap = None
        print(f"[QDRANT] Connected to Qdrant Cloud at {self.url}")

    def create_collection(
        self,
        collection_name: str,
        embedding_model_name: str,
        metadata: Optional[Dict] = None
    ):
        """Create a new collection in Qdrant, replacing any existing one.

        Args:
            collection_name: Name of the collection
            embedding_model_name: Name of the embedding model
            metadata: Accepted for signature parity with ChromaDBManager;
                this method does not store it.

        Returns:
            The collection name, which is also set as the current collection.
        """
        # Create embedding model to get vector size
        self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name)
        self.embedding_model.load_model()

        # Get vector size from a sample embedding
        self.vector_size = len(self.embedding_model.embed_query("test"))

        # Best-effort delete of a previous collection with the same name.
        try:
            self.client.delete_collection(collection_name)
            print(f"[QDRANT] Deleted existing collection: {collection_name}")
        except Exception:
            pass

        self.client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(
                size=self.vector_size,
                distance=Distance.COSINE
            )
        )
        self.current_collection = collection_name
        print(f"[QDRANT] Created collection: {collection_name} (vector_size={self.vector_size})")
        return self.current_collection

    def get_collection(self, collection_name: str):
        """Select an existing collection and load its embedding model.

        The model name is taken from the first point's payload when present,
        otherwise inferred from the collection name, otherwise a default
        model is used.

        Args:
            collection_name: Name of the collection

        Returns:
            The collection name.

        Raises:
            ValueError: If the collection does not exist.
        """
        # Verify collection exists
        if collection_name not in self.list_collections():
            raise ValueError(f"Collection '{collection_name}' not found")
        self.current_collection = collection_name

        info = self.client.get_collection(collection_name)
        self.vector_size = info.config.params.vectors.size

        # Try to load settings from the first point's payload
        embedding_model_name = None
        try:
            points, _ = self.client.scroll(
                collection_name=collection_name,
                limit=1,
                with_payload=True
            )
            if points:
                payload = points[0].payload or {}
                embedding_model_name = payload.get("embedding_model")
                restorable = (
                    ("chunking_strategy", "chunking_strategy"),
                    ("chunk_size", "chunk_size"),
                    ("overlap", "chunk_overlap"),
                )
                for payload_key, attr_name in restorable:
                    if payload_key in payload:
                        setattr(self, attr_name, payload[payload_key])
        except Exception as e:
            print(f"[QDRANT] Warning: Could not retrieve metadata: {e}")

        # If not found in the payload, try to infer from the collection name
        # (expected format: dataset_strategy_modelname), ignoring separators.
        if not embedding_model_name:
            squashed = collection_name.lower().replace("-", "").replace("_", "")
            for model in self._KNOWN_MODELS:
                if model.lower().replace("-", "") in squashed:
                    embedding_model_name = f"sentence-transformers/{model}"
                    break

        # Default fallback
        if not embedding_model_name:
            embedding_model_name = self._DEFAULT_MODEL
            print(f"[QDRANT] Warning: Could not determine embedding model, using default: {embedding_model_name}")

        self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name)
        self.embedding_model.load_model()
        print(f"[QDRANT] Loaded embedding model: {embedding_model_name}")
        print(f"[QDRANT] Loaded collection: {collection_name}")
        return self.current_collection

    def list_collections(self) -> List[str]:
        """List all collections.

        Returns:
            List of collection names
        """
        return [col.name for col in self.client.get_collections().collections]

    def clear_all_collections(self) -> int:
        """Delete all collections (same contract as ChromaDBManager).

        Returns:
            Number of collections deleted
        """
        count = 0
        for name in self.list_collections():
            try:
                self.client.delete_collection(name)
                print(f"[QDRANT] Deleted collection: {name}")
                count += 1
            except Exception as e:
                print(f"[QDRANT] Error deleting collection {name}: {e}")
        self.current_collection = None
        self.embedding_model = None
        print(f"[QDRANT] Cleared {count} collections")
        return count

    @staticmethod
    def _to_point_id(raw_id):
        """Map an arbitrary document id onto a Qdrant point id.

        Non-negative integers and UUID strings are used as they are (UUIDs in
        canonical form); any other value is mapped to a deterministic UUID5
        so the same input always addresses the same point.
        """
        if isinstance(raw_id, int) and not isinstance(raw_id, bool) and raw_id >= 0:
            return raw_id
        text = str(raw_id)
        try:
            return str(uuid.UUID(text))
        except ValueError:
            return str(uuid.uuid5(uuid.NAMESPACE_URL, text))

    def add_documents(
        self,
        documents: List[str],
        metadatas: Optional[List[Dict]] = None,
        ids: Optional[List[str]] = None,
        batch_size: int = 100
    ):
        """Add documents to the current collection.

        Args:
            documents: List of document texts
            metadatas: List of metadata dictionaries (defaults to
                ``{"index": i}`` per document)
            ids: List of document IDs (defaults to random UUID4 strings).
                Reusing an id overwrites that point; omitting ids appends.
            batch_size: Batch size for processing

        Raises:
            ValueError: If no collection or embedding model is loaded, or if
                ``metadatas``/``ids`` differ in length from ``documents``.
        """
        if not self.current_collection:
            raise ValueError("No collection selected. Create or get a collection first.")
        if self.embedding_model is None:
            raise ValueError("No embedding model loaded.")

        total_docs = len(documents)

        # Generate IDs if not provided
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in documents]
        # Generate default metadata if not provided
        if metadatas is None:
            metadatas = [{"index": i} for i in range(total_docs)]

        if len(ids) != total_docs or len(metadatas) != total_docs:
            raise ValueError(
                f"Length mismatch: {total_docs} documents, "
                f"{len(metadatas)} metadatas, {len(ids)} ids"
            )

        print(f"[QDRANT] Adding {total_docs} documents to collection...")
        total_batches = (total_docs - 1) // batch_size + 1
        for start in range(0, total_docs, batch_size):
            stop = start + batch_size
            batch_docs = documents[start:stop]

            # Generate embeddings
            embeddings = self.embedding_model.embed_documents(batch_docs)

            # Point ids come from `ids`, not the batch position, so a later
            # call does not overwrite points written by an earlier one.
            points = [
                PointStruct(
                    id=self._to_point_id(doc_id),
                    vector=embedding.tolist(),
                    payload={**meta, "text": doc}  # Document text lives in the payload
                )
                for doc, embedding, meta, doc_id in zip(
                    batch_docs, embeddings, metadatas[start:stop], ids[start:stop]
                )
            ]

            self.client.upsert(
                collection_name=self.current_collection,
                points=points
            )
            print(f"[QDRANT] Added batch {start // batch_size + 1}/{total_batches}")

        print(f"[QDRANT] Successfully added {total_docs} documents")

    def load_dataset_into_collection(
        self,
        collection_name: str,
        embedding_model_name: str,
        chunking_strategy: str,
        dataset_data: List[Dict],
        chunk_size: int = 512,
        overlap: int = 50
    ):
        """Load a dataset into a new collection with chunking.

        The embedding model name and chunking settings are written into every
        point's payload so get_collection() can restore them later.

        Args:
            collection_name: Name for the new collection
            embedding_model_name: Embedding model to use
            chunking_strategy: Chunking strategy to use
            dataset_data: List of dataset samples; each sample supplies its
                texts under "documents", or a single text under "context"
            chunk_size: Size of chunks
            overlap: Overlap between chunks
        """
        self.chunking_strategy = chunking_strategy
        self.chunk_size = chunk_size
        self.chunk_overlap = overlap

        self.create_collection(collection_name, embedding_model_name)
        chunker = ChunkingFactory.create_chunker(chunking_strategy)

        all_chunks = []
        all_metadatas = []
        print(f"[QDRANT] Processing {len(dataset_data)} documents with {chunking_strategy} chunking...")
        for idx, sample in enumerate(dataset_data):
            documents = sample.get("documents", [])
            if isinstance(documents, str):
                # A bare string would otherwise be iterated per character.
                documents = [documents]
            if not documents:
                context = sample.get("context", "")
                if context:
                    documents = [context]
            if not documents:
                continue

            for doc_idx, document in enumerate(documents):
                if not document or not str(document).strip():
                    continue
                chunks = chunker.chunk_text(str(document), chunk_size, overlap)
                for chunk_idx, chunk in enumerate(chunks):
                    all_chunks.append(chunk)
                    all_metadatas.append({
                        "doc_id": idx,
                        "doc_idx": doc_idx,
                        "chunk_id": chunk_idx,
                        "question": sample.get("question", ""),
                        "answer": sample.get("answer", ""),
                        "dataset": sample.get("dataset", ""),
                        "total_docs": len(documents),
                        "embedding_model": embedding_model_name,
                        "chunking_strategy": chunking_strategy,
                        "chunk_size": chunk_size,
                        "overlap": overlap
                    })

        # Add all chunks to collection
        self.add_documents(all_chunks, all_metadatas)
        print(f"[QDRANT] Loaded {len(all_chunks)} chunks from {len(dataset_data)} samples")

    @staticmethod
    def _build_filter(filter_metadata):
        """Translate a flat ``{key: value}`` dict into a Qdrant equality filter.

        Only str/int/bool values are translated. Operator keys (``$...``) and
        other value types are reported and skipped. Returns None when there
        is nothing to filter on.
        """
        if not filter_metadata:
            return None
        conditions = []
        for key, value in filter_metadata.items():
            if str(key).startswith("$") or not isinstance(value, (str, int, bool)):
                print(f"[QDRANT] Warning: unsupported filter condition ignored: {key!r}")
                continue
            conditions.append(
                qdrant_models.FieldCondition(
                    key=key,
                    match=qdrant_models.MatchValue(value=value)
                )
            )
        if not conditions:
            return None
        return qdrant_models.Filter(must=conditions)

    def query(
        self,
        query_text: str,
        n_results: int = 5,
        filter_metadata: Optional[Dict] = None
    ) -> Dict:
        """Query the current collection.

        Args:
            query_text: Query text
            n_results: Number of results to return
            filter_metadata: Flat ``{payload_key: value}`` equality filter

        Returns:
            Query results in ChromaDB-compatible format: a dict with
            "documents", "metadatas" and "distances", each a one-row list.

        Raises:
            ValueError: If no collection or embedding model is loaded.
        """
        if not self.current_collection:
            raise ValueError("No collection selected.")
        if self.embedding_model is None:
            raise ValueError("No embedding model loaded.")

        # Generate query embedding
        query_vector = self.embedding_model.embed_query(query_text).tolist()
        query_filter = self._build_filter(filter_metadata)

        # Use query_points when the installed client provides it; otherwise
        # fall back to the older search call.
        query_points = getattr(self.client, "query_points", None)
        if query_points is not None:
            results = query_points(
                collection_name=self.current_collection,
                query=query_vector,
                query_filter=query_filter,
                limit=n_results,
                with_payload=True
            ).points
        else:
            results = self.client.search(
                collection_name=self.current_collection,
                query_vector=query_vector,
                query_filter=query_filter,
                limit=n_results
            )

        # Convert to ChromaDB-compatible format
        payloads = [r.payload or {} for r in results]
        return {
            "documents": [[p.get("text", "") for p in payloads]],
            "metadatas": [[{k: v for k, v in p.items() if k != "text"} for p in payloads]],
            # Convert similarity to distance
            "distances": [[1 - r.score for r in results]]
        }

    def get_retrieved_documents(
        self,
        query_text: str,
        n_results: int = 5
    ) -> List[Dict]:
        """Get retrieved documents with metadata.

        Args:
            query_text: Query text
            n_results: Number of results

        Returns:
            List of dicts with "document", "metadata" and "distance" keys
        """
        results = self.query(query_text, n_results)
        documents = results["documents"][0]
        metadatas = results["metadatas"][0]
        distance_rows = results.get("distances")
        distances = distance_rows[0] if distance_rows else [None] * len(documents)
        return [
            {"document": doc, "metadata": meta, "distance": dist}
            for doc, meta, dist in zip(documents, metadatas, distances)
        ]

    def delete_collection(self, collection_name: str) -> bool:
        """Delete a specific collection.

        Args:
            collection_name: Name of the collection to delete

        Returns:
            True if deleted successfully, False otherwise
        """
        try:
            self.client.delete_collection(collection_name)
        except Exception as e:
            print(f"[QDRANT] Error deleting collection: {e}")
            return False

        if self.current_collection == collection_name:
            self.current_collection = None
            self.embedding_model = None
        print(f"[QDRANT] Deleted collection: {collection_name}")
        return True

    def get_collection_stats(self, collection_name: Optional[str] = None) -> Dict:
        """Get statistics for a collection.

        Args:
            collection_name: Name of collection (uses current if None)

        Returns:
            Dictionary with the collection's name, point count, vector size
            and status

        Raises:
            ValueError: If no name is given and no collection is selected.
        """
        coll_name = collection_name or self.current_collection
        if not coll_name:
            raise ValueError("No collection specified or selected")

        info = self.client.get_collection(coll_name)
        return {
            "name": coll_name,
            "count": info.points_count,
            "vector_size": info.config.params.vectors.size,
            "status": info.status
        }
def create_vector_store(provider: str = "chroma", **kwargs):
    """Factory function to create vector store manager.

    Args:
        provider: "chroma" or "qdrant". Matching ignores case and surrounding
            whitespace; any value other than "qdrant" (including None)
            selects ChromaDB.
        **kwargs: Provider-specific arguments: ``url`` and ``api_key`` for
            Qdrant (each falling back to QDRANT_URL / QDRANT_API_KEY),
            ``persist_directory`` for ChromaDB.

    Returns:
        ChromaDBManager or QdrantManager instance

    Raises:
        ImportError: If "qdrant" is requested but qdrant-client is missing.
    """
    # Normalize so that e.g. "Qdrant" does not silently select ChromaDB.
    normalized = str(provider or "chroma").strip().lower()

    if normalized != "qdrant":
        return ChromaDBManager(
            persist_directory=kwargs.get("persist_directory", "./chroma_db")
        )

    if not QDRANT_AVAILABLE:
        raise ImportError("qdrant-client not installed. Run: pip install qdrant-client")
    return QdrantManager(
        url=kwargs.get("url") or os.environ.get("QDRANT_URL"),
        api_key=kwargs.get("api_key") or os.environ.get("QDRANT_API_KEY")
    )