# CapStoneRAG10 / vector_store.py
# Fix: Update Qdrant query API for qdrant-client 1.7+ (commit bc4016b)
"""ChromaDB and Qdrant integration for vector storage and retrieval."""
from typing import List, Dict, Optional, Tuple
import chromadb
from chromadb.config import Settings
import uuid
import os
from embedding_models import EmbeddingFactory, EmbeddingModel
from chunking_strategies import ChunkingFactory
import json
# Qdrant imports (optional - for cloud deployment)
try:
from qdrant_client import QdrantClient
from qdrant_client.http import models as qdrant_models
from qdrant_client.http.models import Distance, VectorParams, PointStruct
QDRANT_AVAILABLE = True
except ImportError:
QDRANT_AVAILABLE = False
print("Warning: qdrant-client not installed. Qdrant support disabled.")
class ChromaDBManager:
    """Manager for ChromaDB operations.

    Wraps a persistent ChromaDB client and tracks the currently selected
    collection, its embedding model, and the chunking parameters used to
    build it (for evaluation reproducibility).
    """

    def __init__(self, persist_directory: str = "./chroma_db"):
        """Initialize ChromaDB manager.

        Args:
            persist_directory: Directory to persist ChromaDB data
        """
        self.persist_directory = persist_directory
        os.makedirs(persist_directory, exist_ok=True)
        # Initialize ChromaDB client with persistent storage; fall back to a
        # regular (in-process) client if the persistent one cannot be created.
        try:
            self.client = chromadb.PersistentClient(
                path=persist_directory,
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True  # Allow reset if needed
                )
            )
        except Exception as e:
            print(f"Warning: Could not create persistent client: {e}")
            print("Falling back to regular client...")
            self.client = chromadb.Client(Settings(
                persist_directory=persist_directory,
                anonymized_telemetry=False,
                allow_reset=True
            ))
        self.embedding_model = None
        self.current_collection = None
        # Track evaluation-related metadata for reproducibility
        self.chunking_strategy = None
        self.chunk_size = None
        self.chunk_overlap = None

    def reconnect(self):
        """Reconnect to ChromaDB in case of connection loss."""
        try:
            self.client = chromadb.PersistentClient(
                path=self.persist_directory,
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True
                )
            )
            print("✅ Reconnected to ChromaDB")
        except Exception as e:
            print(f"Error reconnecting: {e}")

    def create_collection(
        self,
        collection_name: str,
        embedding_model_name: str,
        metadata: Optional[Dict] = None
    ) -> chromadb.Collection:
        """Create a new collection, replacing any existing one with the same name.

        Args:
            collection_name: Name of the collection
            embedding_model_name: Name of the embedding model
            metadata: Additional metadata for the collection

        Returns:
            ChromaDB collection
        """
        # Delete if exists (best-effort: get_collection raises if absent).
        try:
            self.client.delete_collection(collection_name)
        except Exception:
            pass
        # Create embedding model
        self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name)
        self.embedding_model.load_model()
        # Create collection with metadata; "hnsw:space" selects cosine distance.
        collection_metadata = {
            "embedding_model": embedding_model_name,
            "hnsw:space": "cosine"
        }
        if metadata:
            collection_metadata.update(metadata)
        self.current_collection = self.client.create_collection(
            name=collection_name,
            metadata=collection_metadata
        )
        print(f"Created collection: {collection_name}")
        return self.current_collection

    def get_collection(self, collection_name: str) -> chromadb.Collection:
        """Get an existing collection and restore its embedding/chunking config.

        Args:
            collection_name: Name of the collection

        Returns:
            ChromaDB collection
        """
        self.current_collection = self.client.get_collection(collection_name)
        # Load embedding model from metadata (metadata may be None for
        # collections created without any).
        metadata = self.current_collection.metadata or {}
        if "embedding_model" in metadata:
            self.embedding_model = EmbeddingFactory.create_embedding_model(
                metadata["embedding_model"]
            )
            self.embedding_model.load_model()
        # Restore chunking metadata for evaluation reproducibility
        if "chunking_strategy" in metadata:
            self.chunking_strategy = metadata["chunking_strategy"]
        if "chunk_size" in metadata:
            self.chunk_size = metadata["chunk_size"]
        if "overlap" in metadata:
            self.chunk_overlap = metadata["overlap"]
        return self.current_collection

    def list_collections(self) -> List[str]:
        """List all collections.

        Returns:
            List of collection names
        """
        collections = self.client.list_collections()
        return [col.name for col in collections]

    def clear_all_collections(self) -> int:
        """Delete all collections from the database.

        Returns:
            Number of collections deleted
        """
        collections = self.list_collections()
        count = 0
        for collection_name in collections:
            try:
                self.client.delete_collection(collection_name)
                print(f"Deleted collection: {collection_name}")
                count += 1
            except Exception as e:
                print(f"Error deleting collection {collection_name}: {e}")
        self.current_collection = None
        self.embedding_model = None
        print(f"✅ Cleared {count} collections")
        return count

    def delete_collection(self, collection_name: str) -> bool:
        """Delete a specific collection.

        Note: this used to be shadowed by a second, weaker definition later
        in the class; the duplicate has been removed so the bool-returning
        version (which also clears the current-collection state) is in effect.

        Args:
            collection_name: Name of the collection to delete

        Returns:
            True if deleted successfully, False otherwise
        """
        try:
            self.client.delete_collection(collection_name)
            # Drop cached state if we just deleted the active collection.
            if self.current_collection and self.current_collection.name == collection_name:
                self.current_collection = None
                self.embedding_model = None
            print(f"✅ Deleted collection: {collection_name}")
            return True
        except Exception as e:
            print(f"❌ Error deleting collection: {e}")
            return False

    def add_documents(
        self,
        documents: List[str],
        metadatas: Optional[List[Dict]] = None,
        ids: Optional[List[str]] = None,
        batch_size: int = 100
    ):
        """Add documents to the current collection.

        Args:
            documents: List of document texts
            metadatas: List of metadata dictionaries
            ids: List of document IDs
            batch_size: Batch size for processing

        Raises:
            ValueError: If no collection is selected or no embedding model
                is loaded.
        """
        if not self.current_collection:
            raise ValueError("No collection selected. Create or get a collection first.")
        if not self.embedding_model:
            raise ValueError("No embedding model loaded.")
        # Generate IDs if not provided
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in documents]
        # Generate default metadata if not provided
        if metadatas is None:
            metadatas = [{"index": i} for i in range(len(documents))]
        # Process in batches to bound memory used by embedding generation.
        total_docs = len(documents)
        print(f"Adding {total_docs} documents to collection...")
        for i in range(0, total_docs, batch_size):
            batch_docs = documents[i:i + batch_size]
            batch_ids = ids[i:i + batch_size]
            batch_metadatas = metadatas[i:i + batch_size]
            # Generate embeddings
            embeddings = self.embedding_model.embed_documents(batch_docs)
            # Add to collection
            self.current_collection.add(
                documents=batch_docs,
                embeddings=embeddings.tolist(),
                metadatas=batch_metadatas,
                ids=batch_ids
            )
            print(f"Added batch {i//batch_size + 1}/{(total_docs-1)//batch_size + 1}")
        print(f"Successfully added {total_docs} documents")

    def load_dataset_into_collection(
        self,
        collection_name: str,
        embedding_model_name: str,
        chunking_strategy: str,
        dataset_data: List[Dict],
        chunk_size: int = 512,
        overlap: int = 50
    ):
        """Load a dataset into a new collection with chunking.

        Args:
            collection_name: Name for the new collection
            embedding_model_name: Embedding model to use
            chunking_strategy: Chunking strategy to use
            dataset_data: List of dataset samples
            chunk_size: Size of chunks
            overlap: Overlap between chunks
        """
        # Store metadata for later evaluation reference
        self.chunking_strategy = chunking_strategy
        self.chunk_size = chunk_size
        self.chunk_overlap = overlap
        # Create collection
        self.create_collection(
            collection_name,
            embedding_model_name,
            metadata={
                "chunking_strategy": chunking_strategy,
                "chunk_size": chunk_size,
                "overlap": overlap
            }
        )
        # Get chunker
        chunker = ChunkingFactory.create_chunker(chunking_strategy)
        # Process documents
        all_chunks = []
        all_metadatas = []
        print(f"Processing {len(dataset_data)} documents with {chunking_strategy} chunking...")
        for idx, sample in enumerate(dataset_data):
            # Use 'documents' list if available, otherwise fall back to 'context'
            documents = sample.get("documents", [])
            # If documents is empty, use context as fallback
            if not documents:
                context = sample.get("context", "")
                if context:
                    documents = [context]
            if not documents:
                continue
            # Process each document separately for better granularity
            for doc_idx, document in enumerate(documents):
                if not document or not str(document).strip():
                    continue
                # Chunk each document
                chunks = chunker.chunk_text(str(document), chunk_size, overlap)
                # Create metadata for each chunk
                for chunk_idx, chunk in enumerate(chunks):
                    all_chunks.append(chunk)
                    all_metadatas.append({
                        "doc_id": idx,
                        "doc_idx": doc_idx,  # Track which document within the sample
                        "chunk_id": chunk_idx,
                        "question": sample.get("question", ""),
                        "answer": sample.get("answer", ""),
                        "dataset": sample.get("dataset", ""),
                        "total_docs": len(documents)
                    })
        # Add all chunks to collection
        self.add_documents(all_chunks, all_metadatas)
        print(f"Loaded {len(all_chunks)} chunks from {len(dataset_data)} samples")

    def query(
        self,
        query_text: str,
        n_results: int = 5,
        filter_metadata: Optional[Dict] = None
    ) -> Dict:
        """Query the collection.

        Args:
            query_text: Query text
            n_results: Number of results to return
            filter_metadata: Metadata filter

        Returns:
            Query results

        Raises:
            ValueError: If no collection is selected or no embedding model
                is loaded.
        """
        if not self.current_collection:
            raise ValueError("No collection selected.")
        if not self.embedding_model:
            raise ValueError("No embedding model loaded.")
        # Generate query embedding
        query_embedding = self.embedding_model.embed_query(query_text)
        # Query collection with one-shot retry: "default_tenant" in the error
        # indicates a lost ChromaDB connection, so reconnect and retry once.
        try:
            results = self.current_collection.query(
                query_embeddings=[query_embedding.tolist()],
                n_results=n_results,
                where=filter_metadata
            )
        except Exception as e:
            if "default_tenant" in str(e):
                print("Warning: Lost connection to ChromaDB, reconnecting...")
                self.reconnect()
                # Try again after reconnecting
                results = self.current_collection.query(
                    query_embeddings=[query_embedding.tolist()],
                    n_results=n_results,
                    where=filter_metadata
                )
            else:
                raise
        return results

    def get_retrieved_documents(
        self,
        query_text: str,
        n_results: int = 5
    ) -> List[Dict]:
        """Get retrieved documents with metadata.

        Args:
            query_text: Query text
            n_results: Number of results

        Returns:
            List of retrieved documents with metadata
        """
        results = self.query(query_text, n_results)
        retrieved_docs = []
        for i in range(len(results['documents'][0])):
            retrieved_docs.append({
                "document": results['documents'][0][i],
                "metadata": results['metadatas'][0][i],
                "distance": results['distances'][0][i] if 'distances' in results else None
            })
        return retrieved_docs

    def get_collection_stats(self, collection_name: Optional[str] = None) -> Dict:
        """Get statistics for a collection.

        Args:
            collection_name: Name of collection (uses current if None)

        Returns:
            Dictionary with collection statistics

        Raises:
            ValueError: If no collection is specified and none is selected.
        """
        if collection_name:
            collection = self.client.get_collection(collection_name)
        elif self.current_collection:
            collection = self.current_collection
        else:
            raise ValueError("No collection specified or selected")
        count = collection.count()
        metadata = collection.metadata
        return {
            "name": collection.name,
            "count": count,
            "metadata": metadata
        }
class QdrantManager:
    """Manager for Qdrant Cloud operations - persistent storage for HuggingFace Spaces.

    Mirrors the ChromaDBManager interface (create/get/list/query/delete) and
    returns query results in a ChromaDB-compatible dict format so callers can
    switch backends transparently.
    """

    def __init__(self, url: str = None, api_key: str = None):
        """Initialize Qdrant client.

        Args:
            url: Qdrant Cloud URL (e.g., https://xxx.qdrant.io)
            api_key: Qdrant API key

        Raises:
            ImportError: If qdrant-client is not installed.
            ValueError: If url or api_key cannot be resolved (args or env).
        """
        if not QDRANT_AVAILABLE:
            raise ImportError("qdrant-client is not installed. Run: pip install qdrant-client")
        self.url = url or os.environ.get("QDRANT_URL", "")
        self.api_key = api_key or os.environ.get("QDRANT_API_KEY", "")
        if not self.url or not self.api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY are required")
        self.client = QdrantClient(
            url=self.url,
            api_key=self.api_key,
            timeout=60
        )
        self.embedding_model = None
        self.current_collection = None
        self.vector_size = None
        self.chunking_strategy = None
        self.chunk_size = None
        self.chunk_overlap = None
        print(f"[QDRANT] Connected to Qdrant Cloud at {self.url}")

    def create_collection(
        self,
        collection_name: str,
        embedding_model_name: str,
        metadata: Optional[Dict] = None
    ):
        """Create a new collection in Qdrant, replacing any existing one.

        Args:
            collection_name: Name of the collection
            embedding_model_name: Name of the embedding model
            metadata: Additional metadata for the collection (unused here;
                per-point payloads carry metadata in Qdrant)

        Returns:
            The collection name (stored as current collection).
        """
        # Create embedding model to get vector size
        self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name)
        self.embedding_model.load_model()
        # Get vector size from a sample embedding
        sample_embedding = self.embedding_model.embed_query("test")
        self.vector_size = len(sample_embedding)
        # Delete if exists (best-effort; delete raises when absent).
        try:
            self.client.delete_collection(collection_name)
            print(f"[QDRANT] Deleted existing collection: {collection_name}")
        except Exception:
            pass
        # Create collection
        self.client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(
                size=self.vector_size,
                distance=Distance.COSINE
            )
        )
        self.current_collection = collection_name
        print(f"[QDRANT] Created collection: {collection_name} (vector_size={self.vector_size})")
        return self.current_collection

    def get_collection(self, collection_name: str):
        """Get an existing collection and restore its embedding model.

        The embedding model name is recovered from the first point's payload;
        if absent, it is inferred from the collection name against a list of
        known models, falling back to all-mpnet-base-v2.

        Args:
            collection_name: Name of the collection

        Raises:
            ValueError: If the collection does not exist.
        """
        # Verify collection exists
        collections = self.list_collections()
        if collection_name not in collections:
            raise ValueError(f"Collection '{collection_name}' not found")
        self.current_collection = collection_name
        # Get collection info to determine embedding model
        # NOTE(review): assumes a single unnamed vector config — confirm if
        # named vectors are ever used.
        info = self.client.get_collection(collection_name)
        self.vector_size = info.config.params.vectors.size
        # Try to load embedding model from first document's metadata
        embedding_model_name = None
        try:
            # Scroll to get first point
            points, _ = self.client.scroll(
                collection_name=collection_name,
                limit=1,
                with_payload=True
            )
            if points and len(points) > 0:
                payload = points[0].payload
                embedding_model_name = payload.get("embedding_model")
                if "chunking_strategy" in payload:
                    self.chunking_strategy = payload["chunking_strategy"]
        except Exception as e:
            print(f"[QDRANT] Warning: Could not retrieve metadata: {e}")
        # If not found in metadata, try to infer from collection name
        if not embedding_model_name:
            # Collection name format: dataset_strategy_modelname
            # Try common embedding models
            known_models = [
                "all-mpnet-base-v2",
                "all-MiniLM-L6-v2",
                "paraphrase-MiniLM-L6-v2",
                "multi-qa-MiniLM-L6-cos-v1"
            ]
            for model in known_models:
                if model.lower().replace("-", "") in collection_name.lower().replace("-", "").replace("_", ""):
                    embedding_model_name = f"sentence-transformers/{model}"
                    break
        # Default fallback
        if not embedding_model_name:
            embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
            print(f"[QDRANT] Warning: Could not determine embedding model, using default: {embedding_model_name}")
        # Load the embedding model
        if embedding_model_name:
            self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name)
            self.embedding_model.load_model()
            print(f"[QDRANT] Loaded embedding model: {embedding_model_name}")
        print(f"[QDRANT] Loaded collection: {collection_name}")
        return self.current_collection

    def list_collections(self) -> List[str]:
        """List all collections.

        Returns:
            List of collection names
        """
        collections = self.client.get_collections()
        return [col.name for col in collections.collections]

    def add_documents(
        self,
        documents: List[str],
        metadatas: Optional[List[Dict]] = None,
        ids: Optional[List[str]] = None,
        batch_size: int = 100
    ):
        """Add documents to the current collection.

        Args:
            documents: List of document texts
            metadatas: List of metadata dictionaries
            ids: List of document IDs (currently only used for default
                generation; points are keyed by sequential integer IDs)
            batch_size: Batch size for processing

        Raises:
            ValueError: If no collection is selected or no embedding model
                is loaded.
        """
        if not self.current_collection:
            raise ValueError("No collection selected. Create or get a collection first.")
        if not self.embedding_model:
            raise ValueError("No embedding model loaded.")
        # Generate IDs if not provided
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in documents]
        # Generate default metadata if not provided
        if metadatas is None:
            metadatas = [{"index": i} for i in range(len(documents))]
        # Process in batches
        total_docs = len(documents)
        print(f"[QDRANT] Adding {total_docs} documents to collection...")
        for i in range(0, total_docs, batch_size):
            batch_docs = documents[i:i + batch_size]
            batch_ids = ids[i:i + batch_size]
            batch_metadatas = metadatas[i:i + batch_size]
            # Generate embeddings
            embeddings = self.embedding_model.embed_documents(batch_docs)
            # Create points
            points = []
            for j, (doc, embedding, meta, doc_id) in enumerate(zip(batch_docs, embeddings, batch_metadatas, batch_ids)):
                # Store document text in payload
                payload = {**meta, "text": doc}
                points.append(PointStruct(
                    id=i + j,  # Use integer ID
                    vector=embedding.tolist(),
                    payload=payload
                ))
            # Upsert to collection
            self.client.upsert(
                collection_name=self.current_collection,
                points=points
            )
            print(f"[QDRANT] Added batch {i//batch_size + 1}/{(total_docs-1)//batch_size + 1}")
        print(f"[QDRANT] Successfully added {total_docs} documents")

    def load_dataset_into_collection(
        self,
        collection_name: str,
        embedding_model_name: str,
        chunking_strategy: str,
        dataset_data: List[Dict],
        chunk_size: int = 512,
        overlap: int = 50
    ):
        """Load a dataset into a new collection with chunking.

        Args:
            collection_name: Name for the new collection
            embedding_model_name: Embedding model to use
            chunking_strategy: Chunking strategy to use
            dataset_data: List of dataset samples
            chunk_size: Size of chunks
            overlap: Overlap between chunks
        """
        self.chunking_strategy = chunking_strategy
        self.chunk_size = chunk_size
        self.chunk_overlap = overlap
        # Create collection
        self.create_collection(collection_name, embedding_model_name)
        # Get chunker
        chunker = ChunkingFactory.create_chunker(chunking_strategy)
        # Process documents
        all_chunks = []
        all_metadatas = []
        print(f"[QDRANT] Processing {len(dataset_data)} documents with {chunking_strategy} chunking...")
        for idx, sample in enumerate(dataset_data):
            # Use 'documents' list if available, otherwise fall back to 'context'
            documents = sample.get("documents", [])
            if not documents:
                context = sample.get("context", "")
                if context:
                    documents = [context]
            if not documents:
                continue
            for doc_idx, document in enumerate(documents):
                if not document or not str(document).strip():
                    continue
                chunks = chunker.chunk_text(str(document), chunk_size, overlap)
                for chunk_idx, chunk in enumerate(chunks):
                    all_chunks.append(chunk)
                    # embedding_model/chunking_strategy go into every payload
                    # so get_collection() can later recover the configuration.
                    all_metadatas.append({
                        "doc_id": idx,
                        "doc_idx": doc_idx,
                        "chunk_id": chunk_idx,
                        "question": sample.get("question", ""),
                        "answer": sample.get("answer", ""),
                        "dataset": sample.get("dataset", ""),
                        "total_docs": len(documents),
                        "embedding_model": embedding_model_name,
                        "chunking_strategy": chunking_strategy
                    })
        # Add all chunks to collection
        self.add_documents(all_chunks, all_metadatas)
        print(f"[QDRANT] Loaded {len(all_chunks)} chunks from {len(dataset_data)} samples")

    def query(
        self,
        query_text: str,
        n_results: int = 5,
        filter_metadata: Optional[Dict] = None
    ) -> Dict:
        """Query the collection.

        Args:
            query_text: Query text
            n_results: Number of results to return
            filter_metadata: Metadata filter (currently unused for Qdrant)

        Returns:
            Query results in ChromaDB-compatible format

        Raises:
            ValueError: If no collection is selected or no embedding model
                is loaded.
        """
        if not self.current_collection:
            raise ValueError("No collection selected.")
        if not self.embedding_model:
            raise ValueError("No embedding model loaded.")
        # Generate query embedding
        query_embedding = self.embedding_model.embed_query(query_text)
        # Query Qdrant using query_points (qdrant-client >= 1.7); fall back to
        # the deprecated search() on older clients where query_points is absent.
        try:
            results = self.client.query_points(
                collection_name=self.current_collection,
                query=query_embedding.tolist(),
                limit=n_results,
                with_payload=True
            ).points
        except AttributeError:
            # Fallback to older API
            results = self.client.search(
                collection_name=self.current_collection,
                query_vector=query_embedding.tolist(),
                limit=n_results
            )
        # Convert to ChromaDB-compatible format
        documents = [[r.payload.get("text", "") for r in results]]
        metadatas = [[{k: v for k, v in r.payload.items() if k != "text"} for r in results]]
        distances = [[1 - r.score for r in results]]  # Convert similarity to distance
        return {
            "documents": documents,
            "metadatas": metadatas,
            "distances": distances
        }

    def get_retrieved_documents(
        self,
        query_text: str,
        n_results: int = 5
    ) -> List[Dict]:
        """Get retrieved documents with metadata.

        Args:
            query_text: Query text
            n_results: Number of results

        Returns:
            List of retrieved documents with metadata
        """
        results = self.query(query_text, n_results)
        retrieved_docs = []
        for i in range(len(results['documents'][0])):
            retrieved_docs.append({
                "document": results['documents'][0][i],
                "metadata": results['metadatas'][0][i],
                "distance": results['distances'][0][i] if 'distances' in results else None
            })
        return retrieved_docs

    def delete_collection(self, collection_name: str) -> bool:
        """Delete a specific collection.

        Args:
            collection_name: Name of the collection to delete

        Returns:
            True if deleted successfully, False otherwise
        """
        try:
            self.client.delete_collection(collection_name)
            # Drop cached state if we just deleted the active collection.
            if self.current_collection == collection_name:
                self.current_collection = None
                self.embedding_model = None
            print(f"[QDRANT] Deleted collection: {collection_name}")
            return True
        except Exception as e:
            print(f"[QDRANT] Error deleting collection: {e}")
            return False

    def get_collection_stats(self, collection_name: Optional[str] = None) -> Dict:
        """Get statistics for a collection.

        Args:
            collection_name: Name of collection (uses current if None)

        Returns:
            Dictionary with collection statistics

        Raises:
            ValueError: If no collection is specified and none is selected.
        """
        coll_name = collection_name or self.current_collection
        if not coll_name:
            raise ValueError("No collection specified or selected")
        info = self.client.get_collection(coll_name)
        return {
            "name": coll_name,
            "count": info.points_count,
            "vector_size": info.config.params.vectors.size,
            "status": info.status
        }
def create_vector_store(provider: str = "chroma", **kwargs):
    """Factory function to create vector store manager.

    Args:
        provider: "chroma" or "qdrant" (case-insensitive); anything else
            falls back to the Chroma backend, matching previous behavior.
        **kwargs: Provider-specific arguments:
            qdrant: url, api_key (fall back to QDRANT_URL / QDRANT_API_KEY)
            chroma: persist_directory (default "./chroma_db")

    Returns:
        ChromaDBManager or QdrantManager instance

    Raises:
        ImportError: If "qdrant" is requested but qdrant-client is missing.
    """
    # Normalize so "Qdrant" / "QDRANT " select the Qdrant backend instead of
    # silently falling through to Chroma.
    if provider.strip().lower() == "qdrant":
        if not QDRANT_AVAILABLE:
            raise ImportError("qdrant-client not installed. Run: pip install qdrant-client")
        return QdrantManager(
            url=kwargs.get("url") or os.environ.get("QDRANT_URL"),
            api_key=kwargs.get("api_key") or os.environ.get("QDRANT_API_KEY")
        )
    return ChromaDBManager(
        persist_directory=kwargs.get("persist_directory", "./chroma_db")
    )