SmokeScan / rag /vectorstore.py
KinetoLabs's picture
Replace dual 8B with single 30B-A3B FP8 vision model
706520f
"""ChromaDB vector store for FDAM knowledge base.
Provides embedding and storage with metadata support.
Uses mock embeddings when MOCK_MODELS=true for local development.
"""
import hashlib
import logging
from typing import Optional
from pathlib import Path
import chromadb
from chromadb.config import Settings
from config.settings import settings
from .chunker import Chunk
logger = logging.getLogger(__name__)
class MockEmbeddingFunction:
    """Deterministic stand-in for the real embedder, for local development.

    Every text is hashed with SHA-256. The 32 digest bytes are reused
    cyclically to fill ``EMBEDDING_DIM`` slots, mapped into [-1, 1], and the
    result is scaled to unit L2 length, so it has the shape of the real
    model's normalized output (2048 dims, Qwen3-VL-Embedding-2B size).
    """

    EMBEDDING_DIM = 2048  # Per Qwen3-VL-Embedding-2B hidden_size

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Return one pseudo-embedding for each string in ``input``."""
        # ``input`` shadows the builtin, but the parameter name is part of
        # the callable's public signature and is kept unchanged.
        vectors: list[list[float]] = []
        for piece in input:
            vectors.append(self._embed_text(piece))
        return vectors

    def _embed_text(self, text: str) -> list[float]:
        """Map ``text`` to a deterministic, unit-length pseudo-embedding.

        The UTF-8 bytes of ``text`` are SHA-256 hashed. Digest bytes are then
        tiled across ``EMBEDDING_DIM`` positions before L2 normalization.
        """
        import math

        digest = hashlib.sha256(text.encode("utf-8")).digest()
        width = len(digest)
        # Byte value b in [0, 255] becomes b / 127.5 - 1.0, i.e. [-1.0, 1.0].
        raw = [
            (digest[slot % width] / 127.5) - 1.0
            for slot in range(self.EMBEDDING_DIM)
        ]
        magnitude = math.sqrt(sum(value * value for value in raw))
        # Guard against a zero vector before dividing.
        return [value / magnitude for value in raw] if magnitude > 0 else raw
class SharedEmbeddingFunction:
    """Embedding callable backed by the pipeline's already-loaded model.

    Instead of loading a second copy of the embedding model, each call
    fetches the model stack via ``models.loader.get_models()`` and
    delegates to its ``embedding.embed_batch``. Provided so ChromaDB-facing
    code can treat it like any other embedding function.
    """

    EMBEDDING_DIM = 2048  # Per Qwen3-VL-Embedding-2B hidden_size

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Embed every string in ``input`` using the shared model stack."""
        # Imported lazily, at call time rather than module import time.
        from models.loader import get_models

        shared_embedder = get_models().embedding
        return shared_embedder.embed_batch(input)
def get_embedding_function():
    """Choose the embedding function for the configured model mode.

    Returns:
        A ``MockEmbeddingFunction`` when ``settings.mock_models`` is truthy.
        Otherwise a ``SharedEmbeddingFunction``, which reuses the model
        stack's embedding model rather than loading a duplicate.
    """
    factory = MockEmbeddingFunction if settings.mock_models else SharedEmbeddingFunction
    return factory()
class ChromaVectorStore:
    """ChromaDB-based vector store for the FDAM knowledge base.

    Texts are embedded on the client side with ``self.embedding_function``.
    Vectors are stored explicitly in one collection (``COLLECTION_NAME``)
    whose HNSW index uses cosine distance.
    """

    COLLECTION_NAME = "fdam_knowledge_base"

    def __init__(
        self,
        persist_directory: Optional[str] = None,
        embedding_function=None,
    ):
        """Initialize vector store.

        Args:
            persist_directory: Directory for ChromaDB persistence. It is
                created (with parents) if missing. If None or empty, an
                in-memory client is used instead.
            embedding_function: Callable mapping a ``list[str]`` to one
                vector per string. If None, ``get_embedding_function()``
                picks the mock or shared real embedder from settings.
        """
        self.persist_directory = persist_directory

        # Initialize ChromaDB client
        if persist_directory:
            persist_path = Path(persist_directory)
            persist_path.mkdir(parents=True, exist_ok=True)
            logger.debug("ChromaDB: using persistent storage at %s", persist_path)
            self.client = chromadb.PersistentClient(
                path=str(persist_path),
                settings=Settings(anonymized_telemetry=False),
            )
        else:
            logger.debug("ChromaDB: using in-memory storage")
            self.client = chromadb.Client(
                settings=Settings(anonymized_telemetry=False),
            )

        # Set up embedding function. Only derive the "mock"/"real" label
        # from settings when we actually used the settings-based default.
        # A caller-supplied function would otherwise be mislabelled.
        if embedding_function is None:
            embedding_function = get_embedding_function()
            embed_type = "mock" if settings.mock_models else "real"
        else:
            embed_type = f"custom ({type(embedding_function).__name__})"
        self.embedding_function = embedding_function
        logger.debug("ChromaDB: using %s embeddings", embed_type)

        self.collection = self._get_or_create_collection()
        logger.info(
            "ChromaDB collection '%s' ready: %s chunks",
            self.COLLECTION_NAME,
            self.collection.count(),
        )

    def _get_or_create_collection(self):
        """Open (creating if needed) the knowledge-base collection.

        Shared by ``__init__`` and ``clear`` so the cosine distance metric
        is always the same when the collection is (re)created.
        """
        return self.client.get_or_create_collection(
            name=self.COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )

    def add_chunks(self, chunks: list[Chunk]) -> int:
        """Embed and add chunks to the vector store.

        Args:
            chunks: Chunk objects to add. Each contributes its ``id``,
                ``text`` and ``to_metadata()``.

        Returns:
            Number of chunks passed to the collection (0 for empty input).

        NOTE(review): how ChromaDB treats IDs that already exist in the
        collection depends on the ChromaDB version — confirm before
        re-ingesting the same sources.
        """
        if not chunks:
            return 0

        ids = [chunk.id for chunk in chunks]
        documents = [chunk.text for chunk in chunks]
        metadatas = [chunk.to_metadata() for chunk in chunks]

        # Embeddings are computed here, not by ChromaDB: the collection is
        # created without an embedding function.
        embeddings = self.embedding_function(documents)

        self.collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
        )
        return len(chunks)

    def query(
        self,
        query_text: str,
        n_results: int = 5,
        where: Optional[dict] = None,
        where_document: Optional[dict] = None,
    ) -> list[dict]:
        """Query the vector store for the chunks nearest ``query_text``.

        Args:
            query_text: Query text to search for.
            n_results: Maximum number of results to return.
            where: Metadata filter (e.g., {"priority": "primary"}).
            where_document: Document content filter.

        Returns:
            List of result dicts with keys ``id``, ``document``, ``metadata``
            and ``distance`` (cosine distance, per the collection's index
            metric). Empty list when nothing matches.
        """
        # Embed the query with the same function used for ingestion.
        query_embedding = self.embedding_function([query_text])[0]

        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where,
            where_document=where_document,
            include=["documents", "metadatas", "distances"],
        )

        # ChromaDB returns one inner list per query embedding; we sent one.
        if not results["ids"] or not results["ids"][0]:
            return []
        return [
            {
                "id": chunk_id,
                "document": document,
                "metadata": metadata,
                "distance": distance,
            }
            for chunk_id, document, metadata, distance in zip(
                results["ids"][0],
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0],
            )
        ]

    def get_stats(self) -> dict:
        """Get collection statistics.

        Returns:
            Dict with ``total_chunks``, per-value counts of the ``category``
            and ``priority`` metadata fields (missing values counted as
            "unknown"), ``collection_name`` and ``persist_directory``.
        """
        from collections import Counter

        count = self.collection.count()
        categories: Counter = Counter()
        priorities: Counter = Counter()

        if count > 0:
            # Reads every stored metadata record: O(total chunks) memory.
            all_results = self.collection.get(include=["metadatas"])
            for metadata in all_results["metadatas"] or []:
                # Records stored without metadata come back as None; count
                # them as "unknown" instead of raising AttributeError.
                metadata = metadata or {}
                categories[metadata.get("category", "unknown")] += 1
                priorities[metadata.get("priority", "unknown")] += 1

        return {
            "total_chunks": count,
            "categories": dict(categories),
            "priorities": dict(priorities),
            "collection_name": self.COLLECTION_NAME,
            "persist_directory": self.persist_directory,
        }

    def clear(self):
        """Clear all data by dropping and recreating the collection."""
        self.client.delete_collection(self.COLLECTION_NAME)
        self.collection = self._get_or_create_collection()

    def delete_by_source(self, source: str) -> int:
        """Delete all chunks whose ``source`` metadata equals ``source``.

        Args:
            source: Source filename to delete.

        Returns:
            Number of chunks deleted (0 if none matched).
        """
        # Only the IDs are needed, so skip fetching documents/metadata.
        matching_ids = self.collection.get(
            where={"source": source},
            include=[],
        )["ids"]
        if not matching_ids:
            return 0
        self.collection.delete(ids=matching_ids)
        return len(matching_ids)