# zeta/src/embedding/vector_store.py
# Provenance: commit 9b457ed (rodrigo-moonray) —
# "Deploy zeta-only embeddings (NV-Embed-v2 + E5-small)"
"""
ChromaDB vector storage interface.
This module provides a clean interface to ChromaDB for storing and retrieving
document chunks with their embeddings and metadata.
"""
import chromadb
from typing import List, Optional
import numpy as np
import json
from datetime import datetime
from src.config.settings import get_settings, get_collection_name_for_model, EMBEDDING_MODELS
from src.utils.logging import get_logger
from src.ingestion.models import Chunk
logger = get_logger(__name__)
class VectorStore:
    """ChromaDB interface for vector storage.

    Wraps a persistent ChromaDB client and exposes add / query / delete
    operations over model-specific collections, so embeddings produced by
    different models (with different dimensions) never mix in one collection.
    """

    def __init__(self, embedding_model: Optional[str] = None):
        """
        Initialize vector store with settings from configuration.

        Args:
            embedding_model: Optional embedding model ID. If provided, uses a
                model-specific collection; otherwise the configured default
                embedding model is used.
        """
        settings = get_settings()
        self.persist_dir = settings.chroma_persist_dir
        self._base_collection_name = settings.chroma_collection_name
        self._embedding_model = embedding_model or settings.embedding_model
        # Each embedding model gets its own collection name derived from the
        # base name, so vectors of different dimensionality never collide.
        self.collection_name = get_collection_name_for_model(
            self._embedding_model,
            self._base_collection_name,
        )
        # Lazily created on first access (see `client` / `get_collection`).
        self._client = None
        self._collection = None

    @property
    def client(self):
        """
        Lazily initialize the ChromaDB client.

        Returns:
            chromadb.Client: Persistent ChromaDB client instance.
        """
        if self._client is None:
            logger.info("Initializing ChromaDB client: %s", self.persist_dir)
            self._client = chromadb.PersistentClient(path=self.persist_dir)
            logger.debug("ChromaDB client initialized")
        return self._client

    def get_collection(self):
        """
        Get or create the active collection (cached after first call).

        Returns:
            chromadb.Collection: Collection instance.
        """
        if self._collection is None:
            self._collection = self.client.get_or_create_collection(
                name=self.collection_name,
                metadata={"description": "Hierarchical PDF chunks with embeddings"}
            )
            logger.info("Collection loaded: %s", self.collection_name)
        return self._collection

    def add_chunks(self, chunks: List[Chunk], embeddings: np.ndarray):
        """
        Add chunks with embeddings to ChromaDB.

        Args:
            chunks: List of chunks to store.
            embeddings: Numpy array of embeddings (num_chunks x embedding_dim).

        Raises:
            ValueError: If the number of chunks and embeddings differ.
        """
        if len(chunks) != len(embeddings):
            raise ValueError(f"Number of chunks ({len(chunks)}) != number of embeddings ({len(embeddings)})")
        if not chunks:
            # ChromaDB raises on empty add(); nothing to store anyway.
            logger.debug("add_chunks called with no chunks; skipping")
            return
        collection = self.get_collection()
        # ChromaDB wants parallel lists of ids / documents / metadatas.
        ids = [str(chunk.chunk_id) for chunk in chunks]
        documents = [chunk.text for chunk in chunks]
        metadatas = [self._prepare_metadata(chunk) for chunk in chunks]
        logger.info("Adding %d chunks to ChromaDB", len(chunks))
        collection.add(
            ids=ids,
            embeddings=embeddings.tolist(),
            documents=documents,
            metadatas=metadatas
        )
        logger.info("Successfully added %d chunks", len(chunks))

    def _prepare_metadata(self, chunk: Chunk) -> dict:
        """
        Prepare metadata for ChromaDB storage.

        ChromaDB metadata values can only be str, int, float, or bool, so
        lists (page_numbers) are JSON-encoded and optional UUIDs are coerced
        to strings ("" when absent).

        Args:
            chunk: Chunk to extract metadata from.

        Returns:
            dict: Flat metadata dictionary of ChromaDB-compatible values.
        """
        return {
            "chunk_id": str(chunk.chunk_id),
            "document_id": str(chunk.document_id),
            "parent_id": str(chunk.parent_id) if chunk.parent_id else "",
            "chunk_type": chunk.chunk_type,
            "token_count": chunk.token_count,
            "chunk_index": chunk.chunk_index,
            "page_numbers": json.dumps(chunk.page_numbers),
            "start_char": chunk.start_char,
            "end_char": chunk.end_char,
            "file_hash": chunk.file_hash,
            "filename": chunk.filename,
        }

    def document_exists(self, file_hash: str) -> bool:
        """
        Check if a document with the given hash already exists.

        Args:
            file_hash: SHA256 hash of the document.

        Returns:
            bool: True if at least one chunk with this file_hash is stored.
        """
        collection = self.get_collection()
        try:
            # One matching chunk is enough to prove the document is present.
            results = collection.get(
                where={"file_hash": file_hash},
                limit=1
            )
            exists = len(results['ids']) > 0
            if exists:
                logger.debug("Document with hash %s... already exists", file_hash[:8])
            return exists
        except Exception as e:
            # Best-effort: treat any lookup failure as "not present" so
            # ingestion can proceed rather than abort.
            logger.debug("Document check failed: %s", e)
            return False

    def get_chunk(self, chunk_id: str) -> Optional[dict]:
        """
        Retrieve a specific chunk by ID.

        Args:
            chunk_id: UUID of the chunk to retrieve.

        Returns:
            Optional[dict]: Dict with id/document/metadata/embedding keys,
            or None if not found or on error.
        """
        collection = self.get_collection()
        try:
            results = collection.get(
                ids=[chunk_id],
                include=["documents", "metadatas", "embeddings"]
            )
            if len(results['ids']) > 0:
                # Newer ChromaDB returns embeddings as a numpy array; a bare
                # truthiness test on it raises ValueError, so check explicitly.
                embs = results.get('embeddings')
                has_embedding = embs is not None and len(embs) > 0
                return {
                    "id": results['ids'][0],
                    "document": results['documents'][0],
                    "metadata": results['metadatas'][0],
                    "embedding": embs[0] if has_embedding else None
                }
            return None
        except Exception as e:
            logger.error("Failed to retrieve chunk %s: %s", chunk_id, e)
            return None

    def delete_document(self, document_id: str):
        """
        Delete all chunks for a document.

        Args:
            document_id: UUID of the document to delete.

        Raises:
            Exception: Re-raises any ChromaDB error after logging it.
        """
        collection = self.get_collection()
        try:
            collection.delete(
                where={"document_id": document_id}
            )
            logger.info("Deleted all chunks for document: %s", document_id)
        except Exception as e:
            logger.error("Failed to delete document %s: %s", document_id, e)
            raise

    def get_collection_stats(self) -> dict:
        """
        Get statistics about the active collection.

        Returns:
            dict: Name, chunk count, persist dir, and embedding model;
            empty dict on error.
        """
        collection = self.get_collection()
        try:
            return {
                "name": self.collection_name,
                "total_chunks": collection.count(),
                "persist_dir": self.persist_dir,
                "embedding_model": self._embedding_model,
            }
        except Exception as e:
            logger.error("Failed to get collection stats: %s", e)
            return {}

    def list_all_collections(self) -> List[dict]:
        """
        List all known model collections with their stats.

        Iterates every configured embedding model, reporting a chunk count
        of 0 for collections that do not exist yet.

        Returns:
            List[dict]: One info dict per configured embedding model.
        """
        collections = []
        for model_id, model_config in EMBEDDING_MODELS.items():
            collection_name = get_collection_name_for_model(
                model_id,
                self._base_collection_name
            )
            try:
                count = self.client.get_collection(name=collection_name).count()
            except Exception:
                # Collection doesn't exist yet for this model.
                count = 0
            collections.append({
                "collection_name": collection_name,
                "embedding_model": model_id,
                "model_name": model_config.get("name", model_id),
                "dimensions": model_config.get("dimensions"),
                "total_chunks": count,
                "is_active": model_id == self._embedding_model,
            })
        return collections

    def switch_collection(self, embedding_model: str):
        """
        Switch to a different collection based on embedding model.

        Args:
            embedding_model: Embedding model ID to switch to.
        """
        self._embedding_model = embedding_model
        self.collection_name = get_collection_name_for_model(
            embedding_model,
            self._base_collection_name
        )
        self._collection = None  # Reset cached collection
        logger.info("Switched to collection: %s", self.collection_name)

    def query(
        self,
        query_embedding: np.ndarray,
        top_k: int = 10,
        filter_filenames: Optional[List[str]] = None,
    ) -> dict:
        """
        Query the collection with an embedding.

        Args:
            query_embedding: Query embedding vector (1-D).
            top_k: Number of results to return.
            filter_filenames: Optional list of filenames to restrict results to.

        Returns:
            dict: ChromaDB query results with ids, documents, metadatas, and
            distances (each a list-of-lists, one inner list per query).
        """
        collection = self.get_collection()
        try:
            # Single-filename filters use a direct equality match; multiple
            # filenames need ChromaDB's $in operator.
            where_clause = None
            if filter_filenames:
                if len(filter_filenames) == 1:
                    where_clause = {"filename": filter_filenames[0]}
                else:
                    where_clause = {"filename": {"$in": filter_filenames}}
            results = collection.query(
                query_embeddings=[query_embedding.tolist()],
                n_results=top_k,
                include=["documents", "metadatas", "distances"],
                where=where_clause,
            )
            return results
        except Exception as e:
            logger.error("Query failed: %s", e)
            # Match ChromaDB's nested result shape (one inner list per query)
            # so callers indexing results["ids"][0] don't crash on errors.
            return {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}