"""ChromaDB vector store for FDAM knowledge base.

Provides embedding and storage with metadata support.
Uses mock embeddings when MOCK_MODELS=true for local development.
"""

import hashlib
import logging
import math
from collections import Counter
from pathlib import Path
from typing import Optional, Union

import chromadb
from chromadb.config import Settings

from config.settings import settings
from .chunker import Chunk

logger = logging.getLogger(__name__)


class MockEmbeddingFunction:
    """Mock embedding function for local development.

    Generates deterministic pseudo-embeddings based on text hash.
    Produces 2048-dimensional vectors (matches Qwen3-VL-Embedding-2B).
    """

    EMBEDDING_DIM = 2048  # Per Qwen3-VL-Embedding-2B hidden_size

    # NOTE: the parameter is named ``input`` (shadowing the builtin) on
    # purpose; it is part of the callable's interface and must not be renamed.
    def __call__(self, input: list[str]) -> list[list[float]]:
        """Generate mock embeddings for a list of texts.

        Args:
            input: Texts to embed.

        Returns:
            One ``EMBEDDING_DIM``-long vector per input text, in input order.
        """
        return [self._embed_text(text) for text in input]

    def _embed_text(self, text: str) -> list[float]:
        """Generate a deterministic pseudo-embedding from text.

        Uses the SHA-256 digest of the UTF-8 encoded text, repeated cyclically
        to fill the embedding dimensions, then L2-normalizes the vector
        (matching real model behavior).

        Args:
            text: Text to embed.

        Returns:
            A unit-length vector of ``EMBEDDING_DIM`` floats.
        """
        text_hash = hashlib.sha256(text.encode("utf-8")).digest()
        hash_len = len(text_hash)

        # Use hash bytes cyclically; each byte (0..255) maps to [-1, 1].
        embedding = [
            (text_hash[i % hash_len] / 127.5) - 1.0
            for i in range(self.EMBEDDING_DIM)
        ]

        # L2 normalize. A byte is an integer and can never equal 127.5, so no
        # component is zero; the guard only protects against division by zero
        # if the mapping above is ever changed.
        norm = math.sqrt(sum(x * x for x in embedding))
        if norm > 0:
            embedding = [x / norm for x in embedding]

        return embedding


class SharedEmbeddingFunction:
    """Embedding function that uses the shared model from RealModelStack.

    This avoids loading a duplicate embedding model - instead uses the
    model already loaded by the pipeline at startup.

    For ChromaDB compatibility, this wraps the model stack's embedding model.
    """

    EMBEDDING_DIM = 2048  # Per Qwen3-VL-Embedding-2B hidden_size

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Generate embeddings using the shared model from model stack.

        Args:
            input: Texts to embed.

        Returns:
            Whatever ``model_stack.embedding.embed_batch`` returns for the
            given texts (expected: one vector per text).
        """
        # Imported lazily, at call time, rather than at module import.
        from models.loader import get_models

        model_stack = get_models()
        # Use the shared embedding model (always loaded at startup)
        return model_stack.embedding.embed_batch(input)


def get_embedding_function() -> Union[MockEmbeddingFunction, SharedEmbeddingFunction]:
    """Get appropriate embedding function based on settings.

    Returns:
        A ``MockEmbeddingFunction`` when ``settings.mock_models`` is truthy;
        otherwise a ``SharedEmbeddingFunction``, which wraps the model stack's
        embedding model (no duplicate loading).
    """
    if settings.mock_models:
        return MockEmbeddingFunction()
    return SharedEmbeddingFunction()


class ChromaVectorStore:
    """ChromaDB-based vector store for FDAM knowledge base."""

    COLLECTION_NAME = "fdam_knowledge_base"

    def __init__(
        self,
        persist_directory: Optional[str] = None,
        embedding_function=None,
    ):
        """Initialize vector store.

        Args:
            persist_directory: Directory for ChromaDB persistence (created if
                missing). If None or empty, uses in-memory storage.
            embedding_function: Custom embedding function, called with a list
                of texts and returning one vector per text. If None, uses
                the default chosen by ``get_embedding_function()``.
        """
        self.persist_directory = persist_directory

        # Initialize ChromaDB client
        if persist_directory:
            persist_path = Path(persist_directory)
            persist_path.mkdir(parents=True, exist_ok=True)
            logger.debug("ChromaDB: using persistent storage at %s", persist_path)
            self.client = chromadb.PersistentClient(
                path=str(persist_path),
                settings=Settings(anonymized_telemetry=False),
            )
        else:
            logger.debug("ChromaDB: using in-memory storage")
            self.client = chromadb.Client(
                settings=Settings(anonymized_telemetry=False),
            )

        # Set up embedding function. The logged type reflects the function
        # actually in use: a caller-supplied one is reported as "custom"
        # instead of being mislabelled from the global mock setting.
        if embedding_function is not None:
            self.embedding_function = embedding_function
            embed_type = "custom"
        else:
            self.embedding_function = get_embedding_function()
            embed_type = "mock" if settings.mock_models else "real"
        logger.debug("ChromaDB: using %s embeddings", embed_type)

        # Get or create collection
        self.collection = self._get_or_create_collection()
        logger.info(
            "ChromaDB collection '%s' ready: %s chunks",
            self.COLLECTION_NAME,
            self.collection.count(),
        )

    def _get_or_create_collection(self):
        """Return the knowledge-base collection, creating it if needed.

        Shared by ``__init__`` and ``clear`` so both use the same name and
        the same ``hnsw:space`` = ``cosine`` collection metadata.
        """
        return self.client.get_or_create_collection(
            name=self.COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )

    def add_chunks(self, chunks: list[Chunk]) -> int:
        """Add chunks to the vector store.

        Embeddings are computed here with ``self.embedding_function`` and
        passed to the collection explicitly.

        Args:
            chunks: List of Chunk objects to add

        Returns:
            Number of chunks added (0 for an empty list)
        """
        if not chunks:
            return 0

        ids = [chunk.id for chunk in chunks]
        documents = [chunk.text for chunk in chunks]
        metadatas = [chunk.to_metadata() for chunk in chunks]

        # Generate embeddings
        embeddings = self.embedding_function(documents)

        # Add to collection
        self.collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
        )

        return len(chunks)

    def query(
        self,
        query_text: str,
        n_results: int = 5,
        where: Optional[dict] = None,
        where_document: Optional[dict] = None,
    ) -> list[dict]:
        """Query the vector store.

        Args:
            query_text: Query text to search for
            n_results: Number of results to return
            where: Metadata filter (e.g., {"priority": "primary"})
            where_document: Document content filter

        Returns:
            List of result dicts with keys: id, document, metadata, distance.
            Empty when the collection returns no matches.
        """
        # Generate query embedding
        query_embedding = self.embedding_function([query_text])[0]

        # Query collection
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where,
            where_document=where_document,
            include=["documents", "metadatas", "distances"],
        )

        # Results are nested one list per query embedding; only one was sent,
        # so only index 0 is read.
        ids = results["ids"]
        if not ids or not ids[0]:
            return []

        return [
            {
                "id": chunk_id,
                "document": document,
                "metadata": metadata,
                "distance": distance,
            }
            for chunk_id, document, metadata, distance in zip(
                ids[0],
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0],
            )
        ]

    def get_stats(self) -> dict:
        """Get collection statistics.

        Returns:
            Dict with keys: total_chunks, categories, priorities,
            collection_name, persist_directory. ``categories`` and
            ``priorities`` map each metadata value to its record count;
            records lacking the key (or lacking metadata altogether) are
            counted under "unknown".
        """
        count = self.collection.count()

        categories: Counter = Counter()
        priorities: Counter = Counter()

        if count > 0:
            # Fetch the metadata of every record to build the distributions.
            all_results = self.collection.get(include=["metadatas"])
            for metadata in all_results["metadatas"] or []:
                # A record stored without metadata may come back as None;
                # treat it as empty rather than raising AttributeError.
                metadata = metadata or {}
                categories[metadata.get("category", "unknown")] += 1
                priorities[metadata.get("priority", "unknown")] += 1

        return {
            "total_chunks": count,
            # Plain dicts, so callers keep receiving the same type as before.
            "categories": dict(categories),
            "priorities": dict(priorities),
            "collection_name": self.COLLECTION_NAME,
            "persist_directory": self.persist_directory,
        }

    def clear(self) -> None:
        """Clear all data from the collection.

        Deletes the collection and recreates it empty with the same settings.
        """
        self.client.delete_collection(self.COLLECTION_NAME)
        self.collection = self._get_or_create_collection()

    def delete_by_source(self, source: str) -> int:
        """Delete all chunks from a specific source.

        Args:
            source: Source filename to delete (matched against the
                ``source`` metadata field)

        Returns:
            Number of chunks deleted (0 if none matched)
        """
        # Get IDs of chunks from this source; no other fields are needed.
        results = self.collection.get(
            where={"source": source},
            include=[],
        )

        matching_ids = results["ids"]
        if not matching_ids:
            return 0

        self.collection.delete(ids=matching_ids)
        return len(matching_ids)