"""ChromaDB vector store for FDAM knowledge base.

Provides embedding and storage with metadata support.
Uses mock embeddings when MOCK_MODELS=true for local development.
"""
import hashlib
import logging
import math
from pathlib import Path
from typing import Optional

import chromadb
from chromadb.config import Settings

from config.settings import settings

from .chunker import Chunk
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
class MockEmbeddingFunction:
    """Mock embedding function for local development.

    Generates deterministic pseudo-embeddings based on text hash.
    Produces 2048-dimensional vectors (matches Qwen3-VL-Embedding-2B).
    """

    # Per Qwen3-VL-Embedding-2B hidden_size.
    EMBEDDING_DIM = 2048

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Generate mock embeddings for a list of texts.

        Args:
            input: Texts to embed (name matches ChromaDB's expected signature).

        Returns:
            One EMBEDDING_DIM-length L2-normalized vector per input text.
        """
        return [self._embed_text(text) for text in input]

    def _embed_text(self, text: str) -> list[float]:
        """Generate a deterministic pseudo-embedding from text.

        Uses SHA-256 hash expanded to fill embedding dimensions.
        L2 normalized to match real model output.
        """
        # Hash once; the 32 digest bytes are reused cyclically below.
        text_hash = hashlib.sha256(text.encode("utf-8")).digest()
        digest_len = len(text_hash)
        # Expand hash across all dimensions, mapping each byte from
        # [0, 255] to [-1, 1].
        embedding = [
            (text_hash[i % digest_len] / 127.5) - 1.0
            for i in range(self.EMBEDDING_DIM)
        ]
        # L2 normalize (matching real model behavior); guard against the
        # all-zero vector to avoid division by zero.
        norm = math.sqrt(sum(x * x for x in embedding))
        if norm > 0:
            embedding = [x / norm for x in embedding]
        return embedding
class SharedEmbeddingFunction:
    """Embedding function backed by the pipeline's shared model stack.

    Avoids loading a duplicate embedding model: delegates to the model
    that RealModelStack already loaded at startup. The wrapper exists
    only to satisfy ChromaDB's embedding-function call signature.
    """

    # Per Qwen3-VL-Embedding-2B hidden_size.
    EMBEDDING_DIM = 2048

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Embed the given texts via the shared embedding model."""
        # Imported lazily so this module can be imported without pulling
        # in the model stack.
        from models.loader import get_models

        stack = get_models()
        # The embedding model is always loaded at startup.
        return stack.embedding.embed_batch(input)
def get_embedding_function():
    """Return the embedding function appropriate for current settings.

    Mock mode yields deterministic hash-based vectors; otherwise the
    model stack's shared embedding model is wrapped (no duplicate load).
    """
    factory = MockEmbeddingFunction if settings.mock_models else SharedEmbeddingFunction
    return factory()
class ChromaVectorStore:
    """ChromaDB-based vector store for FDAM knowledge base."""

    COLLECTION_NAME = "fdam_knowledge_base"

    def __init__(
        self,
        persist_directory: Optional[str] = None,
        embedding_function=None,
    ):
        """Initialize vector store.

        Args:
            persist_directory: Directory for ChromaDB persistence.
                If None, uses in-memory storage.
            embedding_function: Custom embedding function.
                If None, uses appropriate default.
        """
        self.persist_directory = persist_directory

        # Initialize ChromaDB client: persistent on disk when a directory
        # is given, otherwise purely in-memory.
        if persist_directory:
            persist_path = Path(persist_directory)
            persist_path.mkdir(parents=True, exist_ok=True)
            logger.debug("ChromaDB: using persistent storage at %s", persist_path)
            self.client = chromadb.PersistentClient(
                path=str(persist_path),
                settings=Settings(anonymized_telemetry=False),
            )
        else:
            logger.debug("ChromaDB: using in-memory storage")
            self.client = chromadb.Client(
                settings=Settings(anonymized_telemetry=False),
            )

        # Set up embedding function (mock vs real decided by settings
        # unless the caller supplied one).
        self.embedding_function = embedding_function or get_embedding_function()
        logger.debug(
            "ChromaDB: using %s embeddings",
            "mock" if settings.mock_models else "real",
        )

        # Get or create collection.
        self.collection = self._open_collection()
        logger.info(
            "ChromaDB collection '%s' ready: %s chunks",
            self.COLLECTION_NAME,
            self.collection.count(),
        )

    def _open_collection(self):
        """Get or create the knowledge-base collection (cosine distance).

        Shared by __init__ and clear() so the collection configuration
        is defined in exactly one place.
        """
        return self.client.get_or_create_collection(
            name=self.COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )

    def add_chunks(self, chunks: list[Chunk]) -> int:
        """Add chunks to the vector store.

        Args:
            chunks: List of Chunk objects to add

        Returns:
            Number of chunks added (0 for an empty list)
        """
        if not chunks:
            return 0

        ids = [chunk.id for chunk in chunks]
        documents = [chunk.text for chunk in chunks]
        metadatas = [chunk.to_metadata() for chunk in chunks]

        # Generate embeddings up front so the collection does not need
        # its own embedding function attached.
        embeddings = self.embedding_function(documents)

        self.collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
        )
        return len(chunks)

    def query(
        self,
        query_text: str,
        n_results: int = 5,
        where: Optional[dict] = None,
        where_document: Optional[dict] = None,
    ) -> list[dict]:
        """Query the vector store.

        Args:
            query_text: Query text to search for
            n_results: Number of results to return
            where: Metadata filter (e.g., {"priority": "primary"})
            where_document: Document content filter

        Returns:
            List of result dicts with keys: id, document, metadata, distance
        """
        # Embed the query with the same function used at ingest time.
        query_embedding = self.embedding_function([query_text])[0]

        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where,
            where_document=where_document,
            include=["documents", "metadatas", "distances"],
        )

        # Flatten chromadb's per-query nested lists (we issued a single
        # query, so only index [0] is populated) into flat result dicts.
        formatted = []
        if results["ids"] and results["ids"][0]:
            for i, chunk_id in enumerate(results["ids"][0]):
                formatted.append(
                    {
                        "id": chunk_id,
                        "document": results["documents"][0][i],
                        "metadata": results["metadatas"][0][i],
                        "distance": results["distances"][0][i],
                    }
                )
        return formatted

    def get_stats(self) -> dict:
        """Get collection statistics.

        Returns:
            Dict with total chunk count, per-category and per-priority
            distributions, collection name, and persist directory.
        """
        count = self.collection.count()

        # Tally category/priority distribution across all stored metadata.
        categories: dict = {}
        priorities: dict = {}
        if count > 0:
            # Fetches metadata for every chunk; acceptable for a
            # knowledge base of this scale.
            all_results = self.collection.get(include=["metadatas"])
            for metadata in all_results["metadatas"]:
                cat = metadata.get("category", "unknown")
                pri = metadata.get("priority", "unknown")
                categories[cat] = categories.get(cat, 0) + 1
                priorities[pri] = priorities.get(pri, 0) + 1

        return {
            "total_chunks": count,
            "categories": categories,
            "priorities": priorities,
            "collection_name": self.COLLECTION_NAME,
            "persist_directory": self.persist_directory,
        }

    def clear(self):
        """Clear all data by deleting and recreating the collection."""
        self.client.delete_collection(self.COLLECTION_NAME)
        self.collection = self._open_collection()

    def delete_by_source(self, source: str) -> int:
        """Delete all chunks from a specific source.

        Args:
            source: Source filename to delete

        Returns:
            Number of chunks deleted
        """
        # Fetch only the IDs (include=[] suppresses documents/metadata)
        # of chunks from this source.
        results = self.collection.get(
            where={"source": source},
            include=[],
        )
        if results["ids"]:
            self.collection.delete(ids=results["ids"])
            return len(results["ids"])
        return 0