|
|
"""ChromaDB integration for vector storage and retrieval.""" |
|
|
from typing import List, Dict, Optional, Tuple |
|
|
import chromadb |
|
|
from chromadb.config import Settings |
|
|
import uuid |
|
|
import os |
|
|
from embedding_models import EmbeddingFactory, EmbeddingModel |
|
|
from chunking_strategies import ChunkingFactory |
|
|
import json |
|
|
|
|
|
|
|
|
class ChromaDBManager:
    """Manager for ChromaDB operations.

    Holds one ChromaDB client, tracks a single "current" collection, and
    keeps the embedding model that was used to populate that collection so
    queries embed with the same model as the stored documents.
    """

    def __init__(self, persist_directory: str = "./chroma_db"):
        """Initialize ChromaDB manager.

        Args:
            persist_directory: Directory to persist ChromaDB data
        """
        self.persist_directory = persist_directory
        os.makedirs(persist_directory, exist_ok=True)

        try:
            self.client = chromadb.PersistentClient(
                path=persist_directory,
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True
                )
            )
        except Exception as e:
            # Fall back to the plain client if the persistent one cannot be
            # created (e.g. incompatible chromadb version or locked directory).
            print(f"Warning: Could not create persistent client: {e}")
            print("Falling back to regular client...")
            self.client = chromadb.Client(Settings(
                persist_directory=persist_directory,
                anonymized_telemetry=False,
                allow_reset=True
            ))

        # Populated by create_collection() / get_collection().
        self.embedding_model = None
        self.current_collection = None

    def reconnect(self):
        """Reconnect to ChromaDB in case of connection loss."""
        try:
            self.client = chromadb.PersistentClient(
                path=self.persist_directory,
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True
                )
            )
            print("✅ Reconnected to ChromaDB")
        except Exception as e:
            print(f"Error reconnecting: {e}")

    def create_collection(
        self,
        collection_name: str,
        embedding_model_name: str,
        metadata: Optional[Dict] = None
    ) -> chromadb.Collection:
        """Create a new collection, replacing any existing one of the same name.

        Args:
            collection_name: Name of the collection
            embedding_model_name: Name of the embedding model
            metadata: Additional metadata for the collection

        Returns:
            ChromaDB collection
        """
        # Best-effort delete so the create below never collides with an
        # existing collection of the same name.
        try:
            self.client.delete_collection(collection_name)
        except Exception:
            pass

        self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name)
        self.embedding_model.load_model()

        # Record the embedding model in the collection metadata so
        # get_collection() can reload the matching model later.
        collection_metadata = {
            "embedding_model": embedding_model_name,
            "hnsw:space": "cosine"
        }
        if metadata:
            collection_metadata.update(metadata)

        self.current_collection = self.client.create_collection(
            name=collection_name,
            metadata=collection_metadata
        )

        print(f"Created collection: {collection_name}")
        return self.current_collection

    def get_collection(self, collection_name: str) -> chromadb.Collection:
        """Get an existing collection and reload its embedding model.

        Args:
            collection_name: Name of the collection

        Returns:
            ChromaDB collection
        """
        self.current_collection = self.client.get_collection(collection_name)

        # metadata can be None for collections created without metadata;
        # normalize so the membership test below cannot raise TypeError.
        metadata = self.current_collection.metadata or {}
        if "embedding_model" in metadata:
            self.embedding_model = EmbeddingFactory.create_embedding_model(
                metadata["embedding_model"]
            )
            self.embedding_model.load_model()

        return self.current_collection

    def list_collections(self) -> List[str]:
        """List all collections.

        Returns:
            List of collection names
        """
        # NOTE(review): assumes list_collections() yields objects with a
        # .name attribute (chromadb < 0.6 behavior) — confirm against the
        # pinned chromadb version.
        collections = self.client.list_collections()
        return [col.name for col in collections]

    def clear_all_collections(self) -> int:
        """Delete all collections from the database.

        Returns:
            Number of collections deleted
        """
        collections = self.list_collections()
        count = 0

        for collection_name in collections:
            try:
                self.client.delete_collection(collection_name)
                print(f"Deleted collection: {collection_name}")
                count += 1
            except Exception as e:
                print(f"Error deleting collection {collection_name}: {e}")

        # Everything is gone, so drop the cached selection too.
        self.current_collection = None
        self.embedding_model = None
        print(f"✅ Cleared {count} collections")
        return count

    def delete_collection(self, collection_name: str) -> bool:
        """Delete a specific collection.

        Note: this class previously defined delete_collection twice; the
        later None-returning duplicate shadowed this one. The two have been
        merged into this single bool-returning implementation.

        Args:
            collection_name: Name of the collection to delete

        Returns:
            True if deleted successfully, False otherwise
        """
        try:
            self.client.delete_collection(collection_name)
            # Drop stale references if the deleted collection was selected.
            if self.current_collection and self.current_collection.name == collection_name:
                self.current_collection = None
                self.embedding_model = None
            print(f"✅ Deleted collection: {collection_name}")
            return True
        except Exception as e:
            print(f"❌ Error deleting collection: {e}")
            return False

    def add_documents(
        self,
        documents: List[str],
        metadatas: Optional[List[Dict]] = None,
        ids: Optional[List[str]] = None,
        batch_size: int = 100
    ):
        """Add documents to the current collection in batches.

        Args:
            documents: List of document texts
            metadatas: List of metadata dictionaries (auto-generated if None)
            ids: List of document IDs (random UUIDs if None)
            batch_size: Batch size for processing

        Raises:
            ValueError: If no collection is selected, no embedding model is
                loaded, or the provided lists have mismatched lengths.
        """
        if not self.current_collection:
            raise ValueError("No collection selected. Create or get a collection first.")

        if not self.embedding_model:
            raise ValueError("No embedding model loaded.")

        if ids is None:
            ids = [str(uuid.uuid4()) for _ in documents]

        if metadatas is None:
            metadatas = [{"index": i} for i in range(len(documents))]

        # Misaligned lists would silently attach the wrong metadata/ids to
        # documents; fail loudly instead.
        if len(ids) != len(documents) or len(metadatas) != len(documents):
            raise ValueError("documents, metadatas and ids must have the same length.")

        total_docs = len(documents)
        print(f"Adding {total_docs} documents to collection...")

        for i in range(0, total_docs, batch_size):
            batch_docs = documents[i:i + batch_size]
            batch_ids = ids[i:i + batch_size]
            batch_metadatas = metadatas[i:i + batch_size]

            # Embed client-side so the collection always stores vectors from
            # the model recorded in its metadata.
            embeddings = self.embedding_model.embed_documents(batch_docs)

            self.current_collection.add(
                documents=batch_docs,
                embeddings=embeddings.tolist(),
                metadatas=batch_metadatas,
                ids=batch_ids
            )

            print(f"Added batch {i//batch_size + 1}/{(total_docs-1)//batch_size + 1}")

        print(f"Successfully added {total_docs} documents")

    def load_dataset_into_collection(
        self,
        collection_name: str,
        embedding_model_name: str,
        chunking_strategy: str,
        dataset_data: List[Dict],
        chunk_size: int = 512,
        overlap: int = 50
    ):
        """Load a dataset into a new collection with chunking.

        Args:
            collection_name: Name for the new collection
            embedding_model_name: Embedding model to use
            chunking_strategy: Chunking strategy to use
            dataset_data: List of dataset samples
            chunk_size: Size of chunks
            overlap: Overlap between chunks
        """
        self.create_collection(
            collection_name,
            embedding_model_name,
            metadata={
                "chunking_strategy": chunking_strategy,
                "chunk_size": chunk_size,
                "overlap": overlap
            }
        )

        chunker = ChunkingFactory.create_chunker(chunking_strategy)

        all_chunks = []
        all_metadatas = []

        print(f"Processing {len(dataset_data)} documents with {chunking_strategy} chunking...")

        for idx, sample in enumerate(dataset_data):
            documents = sample.get("documents", [])

            # Fall back to a single "context" field when the sample has no
            # explicit document list.
            if not documents:
                context = sample.get("context", "")
                if context:
                    documents = [context]

            if not documents:
                continue

            for doc_idx, document in enumerate(documents):
                # Skip empty/whitespace-only documents.
                if not document or not str(document).strip():
                    continue

                chunks = chunker.chunk_text(str(document), chunk_size, overlap)

                # Carry the QA fields on every chunk so retrieval results can
                # be traced back to their source sample.
                for chunk_idx, chunk in enumerate(chunks):
                    all_chunks.append(chunk)
                    all_metadatas.append({
                        "doc_id": idx,
                        "doc_idx": doc_idx,
                        "chunk_id": chunk_idx,
                        "question": sample.get("question", ""),
                        "answer": sample.get("answer", ""),
                        "dataset": sample.get("dataset", ""),
                        "total_docs": len(documents)
                    })

        self.add_documents(all_chunks, all_metadatas)

        print(f"Loaded {len(all_chunks)} chunks from {len(dataset_data)} samples")

    def query(
        self,
        query_text: str,
        n_results: int = 5,
        filter_metadata: Optional[Dict] = None
    ) -> Dict:
        """Query the current collection.

        Args:
            query_text: Query text
            n_results: Number of results to return
            filter_metadata: Metadata filter (ChromaDB ``where`` clause)

        Returns:
            Query results

        Raises:
            ValueError: If no collection is selected or no model is loaded.
        """
        if not self.current_collection:
            raise ValueError("No collection selected.")

        if not self.embedding_model:
            raise ValueError("No embedding model loaded.")

        query_embedding = self.embedding_model.embed_query(query_text)

        try:
            results = self.current_collection.query(
                query_embeddings=[query_embedding.tolist()],
                n_results=n_results,
                where=filter_metadata
            )
        except Exception as e:
            # "default_tenant" errors indicate a dropped client connection;
            # reconnect once and retry. Anything else propagates.
            if "default_tenant" in str(e):
                print("Warning: Lost connection to ChromaDB, reconnecting...")
                self.reconnect()

                results = self.current_collection.query(
                    query_embeddings=[query_embedding.tolist()],
                    n_results=n_results,
                    where=filter_metadata
                )
            else:
                raise

        return results

    def get_retrieved_documents(
        self,
        query_text: str,
        n_results: int = 5
    ) -> List[Dict]:
        """Get retrieved documents with metadata.

        Args:
            query_text: Query text
            n_results: Number of results

        Returns:
            List of dicts with ``document``, ``metadata`` and ``distance``
            keys (``distance`` is None when ChromaDB returns no distances).
        """
        results = self.query(query_text, n_results)

        # 'distances' may be absent or explicitly None depending on the
        # chromadb version/include settings; guard both cases.
        distances = results.get('distances')

        retrieved_docs = []
        for i, document in enumerate(results['documents'][0]):
            retrieved_docs.append({
                "document": document,
                "metadata": results['metadatas'][0][i],
                "distance": distances[0][i] if distances else None
            })

        return retrieved_docs

    def get_collection_stats(self, collection_name: Optional[str] = None) -> Dict:
        """Get statistics for a collection.

        Args:
            collection_name: Name of collection (uses current if None)

        Returns:
            Dictionary with collection statistics

        Raises:
            ValueError: If no collection is specified and none is selected.
        """
        if collection_name:
            collection = self.client.get_collection(collection_name)
        elif self.current_collection:
            collection = self.current_collection
        else:
            raise ValueError("No collection specified or selected")

        count = collection.count()
        metadata = collection.metadata

        return {
            "name": collection.name,
            "count": count,
            "metadata": metadata
        }
|
|
|