# vector_store.py (CapstoneRAG10)
"""ChromaDB and Qdrant integration for vector storage and retrieval."""
from typing import List, Dict, Optional, Tuple
import chromadb
from chromadb.config import Settings
import uuid
import os
from embedding_models import EmbeddingFactory, EmbeddingModel
from chunking_strategies import ChunkingFactory
import json
# Qdrant imports (optional - for cloud deployment)
try:
    from qdrant_client import QdrantClient
    from qdrant_client.http import models as qdrant_models
    from qdrant_client.http.models import Distance, VectorParams, PointStruct
    QDRANT_AVAILABLE = True
except ImportError:
    # qdrant-client is an optional dependency: without it the module still
    # works in ChromaDB-only mode, and QdrantManager raises at construction.
    QDRANT_AVAILABLE = False
    print("Warning: qdrant-client not installed. Qdrant support disabled.")
class ChromaDBManager:
    """Manager for ChromaDB operations (local, persistent vector store).

    Holds one "current" collection at a time plus the embedding model that
    matches it; all add/query operations act on that current collection.
    """

    def __init__(self, persist_directory: str = "./chroma_db"):
        """Initialize ChromaDB manager.

        Args:
            persist_directory: Directory to persist ChromaDB data.
        """
        self.persist_directory = persist_directory
        os.makedirs(persist_directory, exist_ok=True)
        # Prefer the persistent client so data survives process restarts;
        # fall back to an in-memory-style client if it cannot be created.
        try:
            self.client = chromadb.PersistentClient(
                path=persist_directory,
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True  # Allow reset if needed
                )
            )
        except Exception as e:
            print(f"Warning: Could not create persistent client: {e}")
            print("Falling back to regular client...")
            self.client = chromadb.Client(Settings(
                persist_directory=persist_directory,
                anonymized_telemetry=False,
                allow_reset=True
            ))
        self.embedding_model = None
        self.current_collection = None
        # Track evaluation-related metadata for reproducibility.
        self.chunking_strategy = None
        self.chunk_size = None
        self.chunk_overlap = None

    def reconnect(self):
        """Reconnect to ChromaDB in case of connection loss."""
        try:
            self.client = chromadb.PersistentClient(
                path=self.persist_directory,
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True
                )
            )
            print("✅ Reconnected to ChromaDB")
        except Exception as e:
            print(f"Error reconnecting: {e}")

    def create_collection(
        self,
        collection_name: str,
        embedding_model_name: str,
        metadata: Optional[Dict] = None
    ) -> chromadb.Collection:
        """Create a new collection, replacing any existing one with the same name.

        Args:
            collection_name: Name of the collection.
            embedding_model_name: Name of the embedding model.
            metadata: Additional metadata for the collection.

        Returns:
            ChromaDB collection.
        """
        # Delete if exists (best-effort: ignore "not found" errors).
        try:
            self.client.delete_collection(collection_name)
        except Exception:
            pass
        # Create and load the embedding model used for this collection.
        self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name)
        self.embedding_model.load_model()
        # Record the embedding model in the collection metadata so
        # get_collection() can restore it later; use cosine distance.
        collection_metadata = {
            "embedding_model": embedding_model_name,
            "hnsw:space": "cosine"
        }
        if metadata:
            collection_metadata.update(metadata)
        self.current_collection = self.client.create_collection(
            name=collection_name,
            metadata=collection_metadata
        )
        print(f"Created collection: {collection_name}")
        return self.current_collection

    def get_collection(self, collection_name: str) -> chromadb.Collection:
        """Get an existing collection and restore its embedding model.

        Args:
            collection_name: Name of the collection.

        Returns:
            ChromaDB collection.
        """
        self.current_collection = self.client.get_collection(collection_name)
        # Collection metadata may be None; guard before membership tests.
        metadata = self.current_collection.metadata or {}
        # Load embedding model recorded at creation time.
        if "embedding_model" in metadata:
            self.embedding_model = EmbeddingFactory.create_embedding_model(
                metadata["embedding_model"]
            )
            self.embedding_model.load_model()
        # Restore chunking metadata for evaluation reproducibility.
        if "chunking_strategy" in metadata:
            self.chunking_strategy = metadata["chunking_strategy"]
        if "chunk_size" in metadata:
            self.chunk_size = metadata["chunk_size"]
        if "overlap" in metadata:
            self.chunk_overlap = metadata["overlap"]
        return self.current_collection

    def list_collections(self) -> List[str]:
        """List all collections.

        Returns:
            List of collection names.
        """
        collections = self.client.list_collections()
        return [col.name for col in collections]

    def clear_all_collections(self) -> int:
        """Delete all collections from the database.

        Returns:
            Number of collections deleted.
        """
        collections = self.list_collections()
        count = 0
        for collection_name in collections:
            try:
                self.client.delete_collection(collection_name)
                print(f"Deleted collection: {collection_name}")
                count += 1
            except Exception as e:
                print(f"Error deleting collection {collection_name}: {e}")
        self.current_collection = None
        self.embedding_model = None
        print(f"✅ Cleared {count} collections")
        return count

    # NOTE(review): the original file defined delete_collection twice; the
    # later void version shadowed this one, losing the bool return and the
    # current-collection cleanup. Only this version is kept.
    def delete_collection(self, collection_name: str) -> bool:
        """Delete a specific collection.

        Args:
            collection_name: Name of the collection to delete.

        Returns:
            True if deleted successfully, False otherwise.
        """
        try:
            self.client.delete_collection(collection_name)
            # Drop local references if we just deleted the active collection.
            if self.current_collection and self.current_collection.name == collection_name:
                self.current_collection = None
                self.embedding_model = None
            print(f"✅ Deleted collection: {collection_name}")
            return True
        except Exception as e:
            print(f"❌ Error deleting collection: {e}")
            return False

    def add_documents(
        self,
        documents: List[str],
        metadatas: Optional[List[Dict]] = None,
        ids: Optional[List[str]] = None,
        batch_size: int = 100
    ):
        """Add documents to the current collection.

        Args:
            documents: List of document texts.
            metadatas: List of metadata dictionaries (one per document).
            ids: List of document IDs; generated as UUIDs if omitted.
            batch_size: Batch size for embedding + insertion.

        Raises:
            ValueError: If no collection is selected or no embedding model loaded.
        """
        if not self.current_collection:
            raise ValueError("No collection selected. Create or get a collection first.")
        if not self.embedding_model:
            raise ValueError("No embedding model loaded.")
        # Generate IDs if not provided.
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in documents]
        # Generate default metadata if not provided.
        if metadatas is None:
            metadatas = [{"index": i} for i in range(len(documents))]
        # Process in batches to bound memory use during embedding.
        total_docs = len(documents)
        print(f"Adding {total_docs} documents to collection...")
        for i in range(0, total_docs, batch_size):
            batch_docs = documents[i:i + batch_size]
            batch_ids = ids[i:i + batch_size]
            batch_metadatas = metadatas[i:i + batch_size]
            # Generate embeddings for the batch.
            embeddings = self.embedding_model.embed_documents(batch_docs)
            # Add to collection.
            self.current_collection.add(
                documents=batch_docs,
                embeddings=embeddings.tolist(),
                metadatas=batch_metadatas,
                ids=batch_ids
            )
            print(f"Added batch {i//batch_size + 1}/{(total_docs-1)//batch_size + 1}")
        print(f"Successfully added {total_docs} documents")

    def load_dataset_into_collection(
        self,
        collection_name: str,
        embedding_model_name: str,
        chunking_strategy: str,
        dataset_data: List[Dict],
        chunk_size: int = 512,
        overlap: int = 50
    ):
        """Load a dataset into a new collection with chunking.

        Args:
            collection_name: Name for the new collection.
            embedding_model_name: Embedding model to use.
            chunking_strategy: Chunking strategy to use.
            dataset_data: List of dataset samples (dicts with "documents"
                or "context", plus optional "question"/"answer"/"dataset").
            chunk_size: Size of chunks.
            overlap: Overlap between chunks.
        """
        # Store metadata for later evaluation reference.
        self.chunking_strategy = chunking_strategy
        self.chunk_size = chunk_size
        self.chunk_overlap = overlap
        # Create collection, persisting the chunking config in its metadata.
        self.create_collection(
            collection_name,
            embedding_model_name,
            metadata={
                "chunking_strategy": chunking_strategy,
                "chunk_size": chunk_size,
                "overlap": overlap
            }
        )
        # Get chunker.
        chunker = ChunkingFactory.create_chunker(chunking_strategy)
        # Process documents.
        all_chunks = []
        all_metadatas = []
        print(f"Processing {len(dataset_data)} documents with {chunking_strategy} chunking...")
        for idx, sample in enumerate(dataset_data):
            # Use 'documents' list if available, otherwise fall back to 'context'.
            documents = sample.get("documents", [])
            if not documents:
                context = sample.get("context", "")
                if context:
                    documents = [context]
            if not documents:
                continue
            # Process each document separately for better granularity.
            for doc_idx, document in enumerate(documents):
                if not document or not str(document).strip():
                    continue
                # Chunk each document.
                chunks = chunker.chunk_text(str(document), chunk_size, overlap)
                # Create metadata for each chunk.
                for chunk_idx, chunk in enumerate(chunks):
                    all_chunks.append(chunk)
                    all_metadatas.append({
                        "doc_id": idx,
                        "doc_idx": doc_idx,  # Track which document within the sample
                        "chunk_id": chunk_idx,
                        "question": sample.get("question", ""),
                        "answer": sample.get("answer", ""),
                        "dataset": sample.get("dataset", ""),
                        "total_docs": len(documents)
                    })
        # Add all chunks to collection.
        self.add_documents(all_chunks, all_metadatas)
        print(f"Loaded {len(all_chunks)} chunks from {len(dataset_data)} samples")

    def query(
        self,
        query_text: str,
        n_results: int = 5,
        filter_metadata: Optional[Dict] = None
    ) -> Dict:
        """Query the current collection.

        Args:
            query_text: Query text.
            n_results: Number of results to return.
            filter_metadata: Metadata filter (Chroma ``where`` clause).

        Returns:
            Query results.

        Raises:
            ValueError: If no collection is selected or no embedding model loaded.
        """
        if not self.current_collection:
            raise ValueError("No collection selected.")
        if not self.embedding_model:
            raise ValueError("No embedding model loaded.")
        # Generate query embedding.
        query_embedding = self.embedding_model.embed_query(query_text)
        # Query collection with one reconnect-and-retry on the known
        # "default_tenant" connection-loss error; re-raise anything else.
        try:
            results = self.current_collection.query(
                query_embeddings=[query_embedding.tolist()],
                n_results=n_results,
                where=filter_metadata
            )
        except Exception as e:
            if "default_tenant" in str(e):
                print("Warning: Lost connection to ChromaDB, reconnecting...")
                self.reconnect()
                # Try again after reconnecting.
                results = self.current_collection.query(
                    query_embeddings=[query_embedding.tolist()],
                    n_results=n_results,
                    where=filter_metadata
                )
            else:
                raise
        return results

    def get_retrieved_documents(
        self,
        query_text: str,
        n_results: int = 5
    ) -> List[Dict]:
        """Get retrieved documents with metadata.

        Args:
            query_text: Query text.
            n_results: Number of results.

        Returns:
            List of dicts with "document", "metadata" and "distance" keys.
        """
        results = self.query(query_text, n_results)
        retrieved_docs = []
        for i in range(len(results['documents'][0])):
            retrieved_docs.append({
                "document": results['documents'][0][i],
                "metadata": results['metadatas'][0][i],
                "distance": results['distances'][0][i] if 'distances' in results else None
            })
        return retrieved_docs

    def get_collection_stats(self, collection_name: Optional[str] = None) -> Dict:
        """Get statistics for a collection.

        Args:
            collection_name: Name of collection (uses current if None).

        Returns:
            Dictionary with collection statistics.

        Raises:
            ValueError: If no collection is specified or selected.
        """
        if collection_name:
            collection = self.client.get_collection(collection_name)
        elif self.current_collection:
            collection = self.current_collection
        else:
            raise ValueError("No collection specified or selected")
        count = collection.count()
        metadata = collection.metadata
        return {
            "name": collection.name,
            "count": count,
            "metadata": metadata
        }
class QdrantManager:
    """Manager for Qdrant Cloud operations - persistent storage for HuggingFace Spaces.

    Mirrors ChromaDBManager's interface; query() returns results in a
    ChromaDB-compatible dict format so callers can swap backends.
    """

    def __init__(self, url: str = None, api_key: str = None):
        """Initialize Qdrant client.

        Args:
            url: Qdrant Cloud URL (e.g., https://xxx.qdrant.io); falls back
                to the QDRANT_URL environment variable.
            api_key: Qdrant API key; falls back to QDRANT_API_KEY.

        Raises:
            ImportError: If qdrant-client is not installed.
            ValueError: If URL or API key cannot be resolved.
        """
        if not QDRANT_AVAILABLE:
            raise ImportError("qdrant-client is not installed. Run: pip install qdrant-client")
        self.url = url or os.environ.get("QDRANT_URL", "")
        self.api_key = api_key or os.environ.get("QDRANT_API_KEY", "")
        if not self.url or not self.api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY are required")
        self.client = QdrantClient(
            url=self.url,
            api_key=self.api_key,
            timeout=60
        )
        self.embedding_model = None
        self.current_collection = None
        self.vector_size = None
        # Track evaluation-related metadata for reproducibility.
        self.chunking_strategy = None
        self.chunk_size = None
        self.chunk_overlap = None
        print(f"[QDRANT] Connected to Qdrant Cloud at {self.url}")

    def create_collection(
        self,
        collection_name: str,
        embedding_model_name: str,
        metadata: Optional[Dict] = None
    ):
        """Create a new collection in Qdrant, replacing any existing one.

        Args:
            collection_name: Name of the collection.
            embedding_model_name: Name of the embedding model.
            metadata: Additional metadata for the collection (currently
                unused; kept for interface parity with ChromaDBManager).
        """
        # Create embedding model first — its output dimension determines
        # the collection's vector size.
        self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name)
        self.embedding_model.load_model()
        # Get vector size from a sample embedding.
        sample_embedding = self.embedding_model.embed_query("test")
        self.vector_size = len(sample_embedding)
        # Delete if exists (best-effort: ignore "not found" errors).
        try:
            self.client.delete_collection(collection_name)
            print(f"[QDRANT] Deleted existing collection: {collection_name}")
        except Exception:
            pass
        # Create collection with cosine distance.
        self.client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(
                size=self.vector_size,
                distance=Distance.COSINE
            )
        )
        self.current_collection = collection_name
        print(f"[QDRANT] Created collection: {collection_name} (vector_size={self.vector_size})")
        return self.current_collection

    def get_collection(self, collection_name: str):
        """Get an existing collection and restore its embedding model.

        Qdrant has no collection-level metadata, so the embedding model is
        recovered from the first point's payload, then from the collection
        name, then a hard-coded default.

        Args:
            collection_name: Name of the collection.

        Raises:
            ValueError: If the collection does not exist.
        """
        # Verify collection exists.
        collections = self.list_collections()
        if collection_name not in collections:
            raise ValueError(f"Collection '{collection_name}' not found")
        self.current_collection = collection_name
        # Get collection info to determine vector size.
        info = self.client.get_collection(collection_name)
        self.vector_size = info.config.params.vectors.size
        # Try to load embedding model from first document's metadata.
        embedding_model_name = None
        try:
            # Scroll to get first point.
            points, _ = self.client.scroll(
                collection_name=collection_name,
                limit=1,
                with_payload=True
            )
            if points and len(points) > 0:
                payload = points[0].payload
                embedding_model_name = payload.get("embedding_model")
                if "chunking_strategy" in payload:
                    self.chunking_strategy = payload["chunking_strategy"]
        except Exception as e:
            print(f"[QDRANT] Warning: Could not retrieve metadata: {e}")
        # If not found in metadata, try to infer from collection name.
        if not embedding_model_name:
            # Collection name format: dataset_strategy_modelname.
            # Match against common embedding models, ignoring separators.
            known_models = [
                "all-mpnet-base-v2",
                "all-MiniLM-L6-v2",
                "paraphrase-MiniLM-L6-v2",
                "multi-qa-MiniLM-L6-cos-v1"
            ]
            for model in known_models:
                if model.lower().replace("-", "") in collection_name.lower().replace("-", "").replace("_", ""):
                    embedding_model_name = f"sentence-transformers/{model}"
                    break
        # Default fallback.
        if not embedding_model_name:
            embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
            print(f"[QDRANT] Warning: Could not determine embedding model, using default: {embedding_model_name}")
        # Load the embedding model.
        if embedding_model_name:
            self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name)
            self.embedding_model.load_model()
            print(f"[QDRANT] Loaded embedding model: {embedding_model_name}")
        print(f"[QDRANT] Loaded collection: {collection_name}")
        return self.current_collection

    def list_collections(self) -> List[str]:
        """List all collections.

        Returns:
            List of collection names.
        """
        collections = self.client.get_collections()
        return [col.name for col in collections.collections]

    def add_documents(
        self,
        documents: List[str],
        metadatas: Optional[List[Dict]] = None,
        ids: Optional[List[str]] = None,
        batch_size: int = 100
    ):
        """Add documents to the current collection.

        Args:
            documents: List of document texts.
            metadatas: List of metadata dictionaries (one per document).
            ids: List of document IDs. NOTE(review): these are currently
                ignored — points are stored under sequential integer IDs,
                so re-adding documents overwrites by position. Kept as-is
                to preserve existing stored-ID behavior; confirm before
                relying on the ids argument.
            batch_size: Batch size for embedding + upsert.

        Raises:
            ValueError: If no collection is selected or no embedding model loaded.
        """
        if not self.current_collection:
            raise ValueError("No collection selected. Create or get a collection first.")
        if not self.embedding_model:
            raise ValueError("No embedding model loaded.")
        # Generate IDs if not provided.
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in documents]
        # Generate default metadata if not provided.
        if metadatas is None:
            metadatas = [{"index": i} for i in range(len(documents))]
        # Process in batches to bound memory use during embedding.
        total_docs = len(documents)
        print(f"[QDRANT] Adding {total_docs} documents to collection...")
        for i in range(0, total_docs, batch_size):
            batch_docs = documents[i:i + batch_size]
            batch_ids = ids[i:i + batch_size]
            batch_metadatas = metadatas[i:i + batch_size]
            # Generate embeddings.
            embeddings = self.embedding_model.embed_documents(batch_docs)
            # Create points; document text rides along in the payload so
            # query() can return it (Qdrant stores only vectors + payload).
            points = []
            for j, (doc, embedding, meta, doc_id) in enumerate(zip(batch_docs, embeddings, batch_metadatas, batch_ids)):
                payload = {**meta, "text": doc}
                points.append(PointStruct(
                    id=i + j,  # Use integer ID (see NOTE in docstring re: ids)
                    vector=embedding.tolist(),
                    payload=payload
                ))
            # Upsert to collection.
            self.client.upsert(
                collection_name=self.current_collection,
                points=points
            )
            print(f"[QDRANT] Added batch {i//batch_size + 1}/{(total_docs-1)//batch_size + 1}")
        print(f"[QDRANT] Successfully added {total_docs} documents")

    def load_dataset_into_collection(
        self,
        collection_name: str,
        embedding_model_name: str,
        chunking_strategy: str,
        dataset_data: List[Dict],
        chunk_size: int = 512,
        overlap: int = 50
    ):
        """Load a dataset into a new collection with chunking.

        Args:
            collection_name: Name for the new collection.
            embedding_model_name: Embedding model to use.
            chunking_strategy: Chunking strategy to use.
            dataset_data: List of dataset samples (dicts with "documents"
                or "context", plus optional "question"/"answer"/"dataset").
            chunk_size: Size of chunks.
            overlap: Overlap between chunks.
        """
        self.chunking_strategy = chunking_strategy
        self.chunk_size = chunk_size
        self.chunk_overlap = overlap
        # Create collection.
        self.create_collection(collection_name, embedding_model_name)
        # Get chunker.
        chunker = ChunkingFactory.create_chunker(chunking_strategy)
        # Process documents.
        all_chunks = []
        all_metadatas = []
        print(f"[QDRANT] Processing {len(dataset_data)} documents with {chunking_strategy} chunking...")
        for idx, sample in enumerate(dataset_data):
            # Use 'documents' list if available, otherwise fall back to 'context'.
            documents = sample.get("documents", [])
            if not documents:
                context = sample.get("context", "")
                if context:
                    documents = [context]
            if not documents:
                continue
            for doc_idx, document in enumerate(documents):
                if not document or not str(document).strip():
                    continue
                chunks = chunker.chunk_text(str(document), chunk_size, overlap)
                for chunk_idx, chunk in enumerate(chunks):
                    all_chunks.append(chunk)
                    # embedding_model / chunking_strategy are embedded in
                    # every payload because get_collection() recovers them
                    # from the first point (Qdrant has no collection metadata).
                    all_metadatas.append({
                        "doc_id": idx,
                        "doc_idx": doc_idx,
                        "chunk_id": chunk_idx,
                        "question": sample.get("question", ""),
                        "answer": sample.get("answer", ""),
                        "dataset": sample.get("dataset", ""),
                        "total_docs": len(documents),
                        "embedding_model": embedding_model_name,
                        "chunking_strategy": chunking_strategy
                    })
        # Add all chunks to collection.
        self.add_documents(all_chunks, all_metadatas)
        print(f"[QDRANT] Loaded {len(all_chunks)} chunks from {len(dataset_data)} samples")

    def query(
        self,
        query_text: str,
        n_results: int = 5,
        filter_metadata: Optional[Dict] = None
    ) -> Dict:
        """Query the collection.

        Args:
            query_text: Query text.
            n_results: Number of results to return.
            filter_metadata: Metadata filter (currently not forwarded to
                Qdrant; kept for interface parity with ChromaDBManager).

        Returns:
            Query results in ChromaDB-compatible format
            ({"documents": [[...]], "metadatas": [[...]], "distances": [[...]]}).

        Raises:
            ValueError: If no collection is selected or no embedding model loaded.
        """
        if not self.current_collection:
            raise ValueError("No collection selected.")
        if not self.embedding_model:
            raise ValueError("No embedding model loaded.")
        # Generate query embedding.
        query_embedding = self.embedding_model.embed_query(query_text)
        # Query Qdrant using query_points (newer API) or search (older API).
        try:
            # The import doubles as a version probe: qdrant-client versions
            # without QueryRequest raise ImportError and hit the fallback.
            from qdrant_client.http.models import QueryRequest
            results = self.client.query_points(
                collection_name=self.current_collection,
                query=query_embedding.tolist(),
                limit=n_results,
                with_payload=True
            ).points
        except (AttributeError, ImportError):
            # Fallback to older API.
            results = self.client.search(
                collection_name=self.current_collection,
                query_vector=query_embedding.tolist(),
                limit=n_results
            )
        # Convert to ChromaDB-compatible format; "text" holds the document,
        # everything else in the payload is treated as metadata.
        documents = [[r.payload.get("text", "") for r in results]]
        metadatas = [[{k: v for k, v in r.payload.items() if k != "text"} for r in results]]
        distances = [[1 - r.score for r in results]]  # Convert similarity to distance
        return {
            "documents": documents,
            "metadatas": metadatas,
            "distances": distances
        }

    def get_retrieved_documents(
        self,
        query_text: str,
        n_results: int = 5
    ) -> List[Dict]:
        """Get retrieved documents with metadata.

        Args:
            query_text: Query text.
            n_results: Number of results.

        Returns:
            List of dicts with "document", "metadata" and "distance" keys.
        """
        results = self.query(query_text, n_results)
        retrieved_docs = []
        for i in range(len(results['documents'][0])):
            retrieved_docs.append({
                "document": results['documents'][0][i],
                "metadata": results['metadatas'][0][i],
                "distance": results['distances'][0][i] if 'distances' in results else None
            })
        return retrieved_docs

    def delete_collection(self, collection_name: str) -> bool:
        """Delete a specific collection.

        Args:
            collection_name: Name of the collection to delete.

        Returns:
            True if deleted successfully, False otherwise.
        """
        try:
            self.client.delete_collection(collection_name)
            # Drop local references if we just deleted the active collection.
            if self.current_collection == collection_name:
                self.current_collection = None
                self.embedding_model = None
            print(f"[QDRANT] Deleted collection: {collection_name}")
            return True
        except Exception as e:
            print(f"[QDRANT] Error deleting collection: {e}")
            return False

    def get_collection_stats(self, collection_name: Optional[str] = None) -> Dict:
        """Get statistics for a collection.

        Args:
            collection_name: Name of collection (uses current if None).

        Returns:
            Dictionary with collection statistics.

        Raises:
            ValueError: If no collection is specified or selected.
        """
        coll_name = collection_name or self.current_collection
        if not coll_name:
            raise ValueError("No collection specified or selected")
        info = self.client.get_collection(coll_name)
        return {
            "name": coll_name,
            "count": info.points_count,
            "vector_size": info.config.params.vectors.size,
            "status": info.status
        }
def create_vector_store(provider: str = "chroma", **kwargs):
    """Factory function to create vector store manager.

    Args:
        provider: "chroma" (default) or "qdrant".
        **kwargs: Provider-specific arguments — ``persist_directory`` for
            Chroma; ``url`` / ``api_key`` for Qdrant (environment variables
            QDRANT_URL / QDRANT_API_KEY are used as fallbacks).

    Returns:
        ChromaDBManager or QdrantManager instance.

    Raises:
        ImportError: If "qdrant" is requested but qdrant-client is missing.
    """
    if provider != "qdrant":
        # Any provider other than "qdrant" gets the local Chroma backend.
        persist_dir = kwargs.get("persist_directory", "./chroma_db")
        return ChromaDBManager(persist_directory=persist_dir)
    if not QDRANT_AVAILABLE:
        raise ImportError("qdrant-client not installed. Run: pip install qdrant-client")
    qdrant_url = kwargs.get("url") or os.environ.get("QDRANT_URL")
    qdrant_key = kwargs.get("api_key") or os.environ.get("QDRANT_API_KEY")
    return QdrantManager(url=qdrant_url, api_key=qdrant_key)