# RAG10 / vector_store.py
# Author: Vivek Kadamati
# Initial commit (ee444c0)
"""ChromaDB integration for vector storage and retrieval."""
from typing import List, Dict, Optional, Tuple
import chromadb
from chromadb.config import Settings
import uuid
import os
from embedding_models import EmbeddingFactory, EmbeddingModel
from chunking_strategies import ChunkingFactory
import json
class ChromaDBManager:
    """Manager for ChromaDB operations.

    Holds a ChromaDB client plus a notion of the "current" collection and the
    embedding model that populated it. Most methods operate on the current
    collection, so call ``create_collection`` or ``get_collection`` before
    ``add_documents``/``query``.
    """

    def __init__(self, persist_directory: str = "./chroma_db"):
        """Initialize ChromaDB manager.

        Args:
            persist_directory: Directory to persist ChromaDB data
        """
        self.persist_directory = persist_directory
        os.makedirs(persist_directory, exist_ok=True)
        # Prefer a persistent client so stored vectors survive process
        # restarts; fall back to an in-memory client if that fails.
        try:
            self.client = chromadb.PersistentClient(
                path=persist_directory,
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True  # Allow reset if needed
                )
            )
        except Exception as e:
            print(f"Warning: Could not create persistent client: {e}")
            print("Falling back to regular client...")
            self.client = chromadb.Client(Settings(
                persist_directory=persist_directory,
                anonymized_telemetry=False,
                allow_reset=True
            ))
        # Populated by create_collection()/get_collection().
        self.embedding_model: Optional[EmbeddingModel] = None
        self.current_collection: Optional[chromadb.Collection] = None

    def reconnect(self):
        """Reconnect to ChromaDB in case of connection loss."""
        try:
            self.client = chromadb.PersistentClient(
                path=self.persist_directory,
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True
                )
            )
            print("✅ Reconnected to ChromaDB")
        except Exception as e:
            print(f"Error reconnecting: {e}")

    def create_collection(
        self,
        collection_name: str,
        embedding_model_name: str,
        metadata: Optional[Dict] = None
    ) -> chromadb.Collection:
        """Create a new collection, replacing any existing one of that name.

        Also loads the embedding model and records its name in the collection
        metadata so ``get_collection`` can restore it later.

        Args:
            collection_name: Name of the collection
            embedding_model_name: Name of the embedding model
            metadata: Additional metadata for the collection

        Returns:
            ChromaDB collection
        """
        # Delete a pre-existing collection of the same name so we start fresh.
        try:
            self.client.delete_collection(collection_name)
        except Exception:
            # Collection did not exist yet — nothing to delete.
            pass
        # Create embedding model
        self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name)
        self.embedding_model.load_model()
        # Record the embedding model and use cosine distance for the HNSW index.
        collection_metadata = {
            "embedding_model": embedding_model_name,
            "hnsw:space": "cosine"
        }
        if metadata:
            collection_metadata.update(metadata)
        self.current_collection = self.client.create_collection(
            name=collection_name,
            metadata=collection_metadata
        )
        print(f"Created collection: {collection_name}")
        return self.current_collection

    def get_collection(self, collection_name: str) -> chromadb.Collection:
        """Get an existing collection and restore its embedding model.

        Args:
            collection_name: Name of the collection

        Returns:
            ChromaDB collection
        """
        self.current_collection = self.client.get_collection(collection_name)
        # Restore the embedding model recorded at creation time. Guard against
        # metadata being None for collections created outside this manager.
        metadata = self.current_collection.metadata or {}
        if "embedding_model" in metadata:
            self.embedding_model = EmbeddingFactory.create_embedding_model(
                metadata["embedding_model"]
            )
            self.embedding_model.load_model()
        return self.current_collection

    def list_collections(self) -> List[str]:
        """List all collections.

        Returns:
            List of collection names
        """
        collections = self.client.list_collections()
        return [col.name for col in collections]

    def clear_all_collections(self) -> int:
        """Delete all collections from the database.

        Returns:
            Number of collections deleted
        """
        collections = self.list_collections()
        count = 0
        for collection_name in collections:
            try:
                self.client.delete_collection(collection_name)
                print(f"Deleted collection: {collection_name}")
                count += 1
            except Exception as e:
                print(f"Error deleting collection {collection_name}: {e}")
        # Any previously selected collection is gone now.
        self.current_collection = None
        self.embedding_model = None
        print(f"✅ Cleared {count} collections")
        return count

    def delete_collection(self, collection_name: str) -> bool:
        """Delete a specific collection.

        NOTE(fix): this method was previously defined twice in the class; the
        later, weaker definition shadowed this one. The two are now merged
        into this single version, which also clears the current-collection
        state when the deleted collection was selected.

        Args:
            collection_name: Name of the collection to delete

        Returns:
            True if deleted successfully, False otherwise
        """
        try:
            self.client.delete_collection(collection_name)
            # Drop stale references if we just deleted the selected collection.
            if self.current_collection and self.current_collection.name == collection_name:
                self.current_collection = None
                self.embedding_model = None
            print(f"✅ Deleted collection: {collection_name}")
            return True
        except Exception as e:
            print(f"❌ Error deleting collection: {e}")
            return False

    def add_documents(
        self,
        documents: List[str],
        metadatas: Optional[List[Dict]] = None,
        ids: Optional[List[str]] = None,
        batch_size: int = 100
    ):
        """Add documents to the current collection.

        Embeds each batch with the loaded embedding model before insertion.

        Args:
            documents: List of document texts
            metadatas: List of metadata dictionaries (defaults to index-only)
            ids: List of document IDs (defaults to random UUIDs)
            batch_size: Batch size for processing

        Raises:
            ValueError: If no collection is selected or no model is loaded.
        """
        if not self.current_collection:
            raise ValueError("No collection selected. Create or get a collection first.")
        if not self.embedding_model:
            raise ValueError("No embedding model loaded.")
        # Generate IDs if not provided
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in documents]
        # Generate default metadata if not provided
        if metadatas is None:
            metadatas = [{"index": i} for i in range(len(documents))]
        # Process in batches to bound memory use during embedding.
        total_docs = len(documents)
        print(f"Adding {total_docs} documents to collection...")
        for i in range(0, total_docs, batch_size):
            batch_docs = documents[i:i + batch_size]
            batch_ids = ids[i:i + batch_size]
            batch_metadatas = metadatas[i:i + batch_size]
            # Generate embeddings (numpy array → list for the Chroma API).
            embeddings = self.embedding_model.embed_documents(batch_docs)
            self.current_collection.add(
                documents=batch_docs,
                embeddings=embeddings.tolist(),
                metadatas=batch_metadatas,
                ids=batch_ids
            )
            print(f"Added batch {i//batch_size + 1}/{(total_docs-1)//batch_size + 1}")
        print(f"Successfully added {total_docs} documents")

    def load_dataset_into_collection(
        self,
        collection_name: str,
        embedding_model_name: str,
        chunking_strategy: str,
        dataset_data: List[Dict],
        chunk_size: int = 512,
        overlap: int = 50
    ):
        """Load a dataset into a new collection with chunking.

        Each sample's documents (or its ``context`` as a fallback) are split
        with the chosen chunking strategy; every chunk carries metadata tying
        it back to its sample, document, and question/answer pair.

        Args:
            collection_name: Name for the new collection
            embedding_model_name: Embedding model to use
            chunking_strategy: Chunking strategy to use
            dataset_data: List of dataset samples
            chunk_size: Size of chunks
            overlap: Overlap between chunks
        """
        # Create collection (records chunking parameters in its metadata).
        self.create_collection(
            collection_name,
            embedding_model_name,
            metadata={
                "chunking_strategy": chunking_strategy,
                "chunk_size": chunk_size,
                "overlap": overlap
            }
        )
        chunker = ChunkingFactory.create_chunker(chunking_strategy)
        all_chunks = []
        all_metadatas = []
        print(f"Processing {len(dataset_data)} documents with {chunking_strategy} chunking...")
        for idx, sample in enumerate(dataset_data):
            # Use 'documents' list if available, otherwise fall back to 'context'
            documents = sample.get("documents", [])
            if not documents:
                context = sample.get("context", "")
                if context:
                    documents = [context]
            if not documents:
                continue
            # Process each document separately for better granularity
            for doc_idx, document in enumerate(documents):
                if not document or not str(document).strip():
                    continue
                chunks = chunker.chunk_text(str(document), chunk_size, overlap)
                # Create metadata for each chunk
                for chunk_idx, chunk in enumerate(chunks):
                    all_chunks.append(chunk)
                    all_metadatas.append({
                        "doc_id": idx,
                        "doc_idx": doc_idx,  # Track which document within the sample
                        "chunk_id": chunk_idx,
                        "question": sample.get("question", ""),
                        "answer": sample.get("answer", ""),
                        "dataset": sample.get("dataset", ""),
                        "total_docs": len(documents)
                    })
        # Add all chunks to collection
        self.add_documents(all_chunks, all_metadatas)
        print(f"Loaded {len(all_chunks)} chunks from {len(dataset_data)} samples")

    def query(
        self,
        query_text: str,
        n_results: int = 5,
        filter_metadata: Optional[Dict] = None
    ) -> Dict:
        """Query the current collection.

        Args:
            query_text: Query text
            n_results: Number of results to return
            filter_metadata: Metadata filter

        Returns:
            Raw ChromaDB query results

        Raises:
            ValueError: If no collection is selected or no model is loaded.
        """
        if not self.current_collection:
            raise ValueError("No collection selected.")
        if not self.embedding_model:
            raise ValueError("No embedding model loaded.")
        query_embedding = self.embedding_model.embed_query(query_text)
        # One-shot retry on the "default_tenant" error, which indicates a lost
        # ChromaDB connection rather than a bad query.
        try:
            results = self.current_collection.query(
                query_embeddings=[query_embedding.tolist()],
                n_results=n_results,
                where=filter_metadata
            )
        except Exception as e:
            if "default_tenant" in str(e):
                print("Warning: Lost connection to ChromaDB, reconnecting...")
                self.reconnect()
                # Try again after reconnecting
                results = self.current_collection.query(
                    query_embeddings=[query_embedding.tolist()],
                    n_results=n_results,
                    where=filter_metadata
                )
            else:
                raise
        return results

    def get_retrieved_documents(
        self,
        query_text: str,
        n_results: int = 5
    ) -> List[Dict]:
        """Get retrieved documents with metadata.

        Args:
            query_text: Query text
            n_results: Number of results

        Returns:
            List of dicts with ``document``, ``metadata`` and ``distance`` keys
        """
        results = self.query(query_text, n_results)
        retrieved_docs = []
        # Chroma returns per-query lists; we issued one query, hence index [0].
        for i in range(len(results['documents'][0])):
            retrieved_docs.append({
                "document": results['documents'][0][i],
                "metadata": results['metadatas'][0][i],
                "distance": results['distances'][0][i] if 'distances' in results else None
            })
        return retrieved_docs

    def get_collection_stats(self, collection_name: Optional[str] = None) -> Dict:
        """Get statistics for a collection.

        Args:
            collection_name: Name of collection (uses current if None)

        Returns:
            Dictionary with collection statistics

        Raises:
            ValueError: If no name is given and no collection is selected.
        """
        if collection_name:
            collection = self.client.get_collection(collection_name)
        elif self.current_collection:
            collection = self.current_collection
        else:
            raise ValueError("No collection specified or selected")
        count = collection.count()
        metadata = collection.metadata
        return {
            "name": collection.name,
            "count": count,
            "metadata": metadata
        }