Spaces:
Paused
Paused
| """ChromaDB vector store for FDAM knowledge base. | |
| Provides embedding and storage with metadata support. | |
| Uses mock embeddings when MOCK_MODELS=true for local development. | |
| """ | |
| import hashlib | |
| import logging | |
| from typing import Optional | |
| from pathlib import Path | |
| import chromadb | |
| from chromadb.config import Settings | |
| from config.settings import settings | |
| from .chunker import Chunk | |
| logger = logging.getLogger(__name__) | |
class MockEmbeddingFunction:
    """Deterministic stand-in embedder for local development.

    Derives pseudo-embeddings from the SHA-256 digest of the input text,
    yielding 4096-dimensional unit vectors — the same dimensionality as
    Qwen3-VL-Embedding-8B — so downstream code behaves identically.
    """

    # Matches the hidden_size of Qwen3-VL-Embedding-8B.
    EMBEDDING_DIM = 4096

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Return one mock embedding per text in *input*."""
        return [self._embed_text(item) for item in input]

    def _embed_text(self, text: str) -> list[float]:
        """Build a deterministic pseudo-embedding for *text*.

        The digest bytes are tiled cyclically across the embedding
        dimensions, each mapped into [-1, 1], then the vector is
        L2-normalized to mimic the real model's output.
        """
        import math

        digest = hashlib.sha256(text.encode("utf-8")).digest()
        size = len(digest)
        # Cycle through the digest bytes, rescaling each byte into [-1, 1].
        vector = [digest[i % size] / 127.5 - 1.0 for i in range(self.EMBEDDING_DIM)]
        # Unit-normalize, matching the real model's output convention.
        magnitude = math.sqrt(sum(component * component for component in vector))
        if magnitude > 0:
            vector = [component / magnitude for component in vector]
        return vector
class RealEmbeddingFunction:
    """Real embedding function using Qwen3-VL-Embedding-8B.

    Uses last-token pooling per the official Qwen3-VL-Embedding
    implementation. The model and tokenizer are loaded lazily on first
    call, so constructing this class is cheap.

    Reference: https://github.com/QwenLM/Qwen3-VL-Embedding
    """

    EMBEDDING_DIM = 4096  # Per Qwen3-VL-Embedding-8B hidden_size

    def __init__(self):
        # Populated lazily by _load_model() on first __call__.
        self.model = None
        self.tokenizer = None

    def _load_model(self):
        """Lazy-load the tokenizer and model (idempotent)."""
        if self.model is not None:
            return
        import torch
        from transformers import AutoModel, AutoTokenizer

        model_name = "Qwen/Qwen3-VL-Embedding-8B"
        logger.info(f"Loading embedding model: {model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        self.model = AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()

    def _pooling_last(self, hidden_state, attention_mask):
        """Extract the last valid token's hidden state per sequence.

        Official pooling method from Qwen3-VL-Embedding: finds the last
        position where attention_mask == 1 and extracts that token.

        BUG FIX: this method was previously defined without ``self``, so
        the call ``self._pooling_last(hidden, mask)`` bound ``self`` to
        ``hidden_state`` and the tensor to ``attention_mask``, crashing
        at runtime.

        Args:
            hidden_state: Tensor of shape (batch, seq_len, hidden_dim).
            attention_mask: Tensor of shape (batch, seq_len); 1 marks
                real tokens, 0 marks padding.

        Returns:
            Tensor of shape (batch, hidden_dim).
        """
        import torch

        # argmax over the flipped mask gives the distance of the last
        # non-padding token from the end of each sequence.
        flipped_tensor = attention_mask.flip(dims=[1])
        last_one_positions = flipped_tensor.argmax(dim=1)
        col = attention_mask.shape[1] - last_one_positions - 1
        row = torch.arange(hidden_state.shape[0], device=hidden_state.device)
        return hidden_state[row, col]

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Generate embeddings for a list of texts using last-token pooling."""
        self._load_model()
        import torch

        embeddings = []
        with torch.no_grad():
            for text in input:
                inputs = self.tokenizer(
                    text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=512,
                    padding=True,
                )
                # Note: With device_map="auto", transformers handles device
                # routing internally. Do NOT call .to(device) - it breaks
                # distributed models.
                outputs = self.model(**inputs)
                # Use last-token pooling (official Qwen3-VL-Embedding method)
                attention_mask = inputs.get("attention_mask")
                if attention_mask is not None:
                    embedding = self._pooling_last(outputs.last_hidden_state, attention_mask)
                else:
                    # Fallback: use last token if no attention mask
                    embedding = outputs.last_hidden_state[:, -1, :]
                # L2 normalize (per official implementation)
                embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)
                embeddings.append(embedding.squeeze().cpu().float().tolist())
        return embeddings
def get_embedding_function():
    """Select the embedding backend matching the current settings.

    Returns a MockEmbeddingFunction when MOCK_MODELS is enabled,
    otherwise a RealEmbeddingFunction backed by Qwen3-VL-Embedding-8B.
    """
    if not settings.mock_models:
        return RealEmbeddingFunction()
    return MockEmbeddingFunction()
class ChromaVectorStore:
    """ChromaDB-based vector store for FDAM knowledge base."""

    COLLECTION_NAME = "fdam_knowledge_base"

    def __init__(
        self,
        persist_directory: Optional[str] = None,
        embedding_function=None,
    ):
        """Initialize vector store.

        Args:
            persist_directory: Directory for ChromaDB persistence.
                If None, uses in-memory storage.
            embedding_function: Custom embedding function.
                If None, uses appropriate default.
        """
        self.persist_directory = persist_directory
        # Initialize ChromaDB client
        if persist_directory:
            persist_path = Path(persist_directory)
            persist_path.mkdir(parents=True, exist_ok=True)
            logger.debug(f"ChromaDB: using persistent storage at {persist_path}")
            self.client = chromadb.PersistentClient(
                path=str(persist_path),
                settings=Settings(anonymized_telemetry=False),
            )
        else:
            logger.debug("ChromaDB: using in-memory storage")
            self.client = chromadb.Client(
                settings=Settings(anonymized_telemetry=False),
            )
        # Set up embedding function. FIX: previously the log reported
        # "mock"/"real" from settings even when a custom function was
        # injected, which was misleading.
        if embedding_function is not None:
            self.embedding_function = embedding_function
            embed_type = "custom"
        else:
            self.embedding_function = get_embedding_function()
            embed_type = "mock" if settings.mock_models else "real"
        logger.debug(f"ChromaDB: using {embed_type} embeddings")
        # Get or create collection
        self.collection = self._open_collection()
        logger.info(f"ChromaDB collection '{self.COLLECTION_NAME}' ready: {self.collection.count()} chunks")

    def _open_collection(self):
        """Get or create the knowledge-base collection (cosine distance).

        Shared by __init__ and clear() so the collection configuration
        stays consistent.
        """
        return self.client.get_or_create_collection(
            name=self.COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )

    def add_chunks(self, chunks: list[Chunk]) -> int:
        """Add chunks to the vector store.

        Args:
            chunks: List of Chunk objects to add

        Returns:
            Number of chunks added
        """
        if not chunks:
            return 0
        ids = [chunk.id for chunk in chunks]
        documents = [chunk.text for chunk in chunks]
        metadatas = [chunk.to_metadata() for chunk in chunks]
        # Generate embeddings up-front so failures surface before any write.
        embeddings = self.embedding_function(documents)
        # Add to collection
        self.collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
        )
        return len(chunks)

    def query(
        self,
        query_text: str,
        n_results: int = 5,
        where: Optional[dict] = None,
        where_document: Optional[dict] = None,
    ) -> list[dict]:
        """Query the vector store.

        Args:
            query_text: Query text to search for
            n_results: Number of results to return
            where: Metadata filter (e.g., {"priority": "primary"})
            where_document: Document content filter

        Returns:
            List of result dicts with keys: id, document, metadata, distance
        """
        # Embed the query with the same function used for documents.
        query_embedding = self.embedding_function([query_text])[0]
        # Query collection
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where,
            where_document=where_document,
            include=["documents", "metadatas", "distances"],
        )
        # Flatten chroma's per-query nested lists into flat result dicts.
        formatted = []
        if results["ids"] and results["ids"][0]:
            for i, chunk_id in enumerate(results["ids"][0]):
                formatted.append(
                    {
                        "id": chunk_id,
                        "document": results["documents"][0][i],
                        "metadata": results["metadatas"][0][i],
                        "distance": results["distances"][0][i],
                    }
                )
        return formatted

    def get_stats(self) -> dict:
        """Get collection statistics.

        Returns:
            Dict with total chunk count, per-category and per-priority
            counts, collection name, and persist directory.
        """
        count = self.collection.count()
        # Get category distribution
        categories = {}
        priorities = {}
        if count > 0:
            # Fetch all metadata to compute the distributions. Fine for a
            # knowledge base of this size; revisit if collections grow large.
            all_results = self.collection.get(include=["metadatas"])
            for metadata in all_results["metadatas"]:
                cat = metadata.get("category", "unknown")
                pri = metadata.get("priority", "unknown")
                categories[cat] = categories.get(cat, 0) + 1
                priorities[pri] = priorities.get(pri, 0) + 1
        return {
            "total_chunks": count,
            "categories": categories,
            "priorities": priorities,
            "collection_name": self.COLLECTION_NAME,
            "persist_directory": self.persist_directory,
        }

    def clear(self):
        """Clear all data from the collection.

        Drops and recreates the collection with the same configuration.
        """
        self.client.delete_collection(self.COLLECTION_NAME)
        self.collection = self._open_collection()

    def delete_by_source(self, source: str) -> int:
        """Delete all chunks from a specific source.

        Args:
            source: Source filename to delete

        Returns:
            Number of chunks deleted
        """
        # include=[] returns only IDs, avoiding fetching document bodies.
        results = self.collection.get(
            where={"source": source},
            include=[],
        )
        if results["ids"]:
            self.collection.delete(ids=results["ids"])
            return len(results["ids"])
        return 0