SmokeScan / rag / vectorstore.py
KinetoLabs's picture
Initial commit: FDAM AI Pipeline v4.0.1
88bdcff
raw
history blame
8.83 kB
"""ChromaDB vector store for FDAM knowledge base.
Provides embedding and storage with metadata support.
Uses mock embeddings when MOCK_MODELS=true for local development.
"""
import hashlib
from collections import Counter
from pathlib import Path
from typing import Optional

import chromadb
from chromadb.config import Settings

from config.settings import settings

from .chunker import Chunk
class MockEmbeddingFunction:
    """Deterministic stand-in embedding function for local development.

    Derives pseudo-embeddings from the SHA-256 digest of each input text,
    producing 384-dimensional vectors (the size used by many common
    sentence-embedding models). Outputs are reproducible across runs but
    carry no semantic meaning.
    """

    EMBEDDING_DIM = 384

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Return one mock embedding per text in *input*."""
        return [self._embed_text(item) for item in input]

    def _embed_text(self, text: str) -> list[float]:
        """Expand the SHA-256 digest of *text* into a pseudo-embedding.

        The 32 digest bytes are repeated cyclically to fill all
        EMBEDDING_DIM slots; each byte is mapped linearly from [0, 255]
        into [-1.0, 1.0].
        """
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        digest_len = len(digest)
        return [
            (digest[i % digest_len] / 127.5) - 1.0
            for i in range(self.EMBEDDING_DIM)
        ]
class RealEmbeddingFunction:
    """Embedding function backed by Qwen3-VL-Embedding-8B.

    Heavy dependencies (torch / transformers) are imported and the model
    is loaded lazily on first use, so constructing this object stays
    cheap. Intended for MOCK_MODELS=false runs.
    """

    EMBEDDING_DIM = 4096  # hidden size of the Qwen embedding model

    def __init__(self):
        # Both remain None until the first embedding request triggers
        # _load_model().
        self.model = None
        self.tokenizer = None

    def _load_model(self):
        """Load tokenizer and model once; subsequent calls are no-ops."""
        if self.model is not None:
            return
        import torch
        from transformers import AutoModel, AutoTokenizer
        model_name = "Qwen/Qwen3-VL-Embedding-8B"
        print(f"Loading embedding model: {model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        self.model = AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Return one embedding per text in *input*."""
        self._load_model()
        import torch
        with torch.no_grad():
            return [self._embed_one(text) for text in input]

    def _embed_one(self, text: str) -> list[float]:
        """Tokenize a single text and mean-pool its final hidden states."""
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )
        # Move every input tensor onto the model's device before the
        # forward pass.
        encoded = {key: value.to(self.model.device) for key, value in encoded.items()}
        outputs = self.model(**encoded)
        # Mean pooling over the sequence dimension yields one vector.
        pooled = outputs.last_hidden_state.mean(dim=1).squeeze()
        return pooled.cpu().float().tolist()
def get_embedding_function():
    """Select the embedding function matching the current settings.

    Returns a MockEmbeddingFunction when MOCK_MODELS is enabled,
    otherwise a RealEmbeddingFunction (model loaded lazily on first use).
    """
    return MockEmbeddingFunction() if settings.mock_models else RealEmbeddingFunction()
class ChromaVectorStore:
    """ChromaDB-based vector store for FDAM knowledge base.

    Wraps a single Chroma collection (cosine HNSW space). Embeddings are
    generated explicitly via a pluggable embedding function rather than
    delegated to Chroma, so mock and real embeddings share one code path.
    """

    COLLECTION_NAME = "fdam_knowledge_base"

    def __init__(
        self,
        persist_directory: Optional[str] = None,
        embedding_function=None,
    ):
        """Initialize vector store.

        Args:
            persist_directory: Directory for ChromaDB persistence.
                If None, uses in-memory storage.
            embedding_function: Custom embedding function.
                If None, uses the default selected by settings.
        """
        self.persist_directory = persist_directory
        # Persistent client when a directory is given, in-memory otherwise.
        if persist_directory:
            persist_path = Path(persist_directory)
            persist_path.mkdir(parents=True, exist_ok=True)
            self.client = chromadb.PersistentClient(
                path=str(persist_path),
                settings=Settings(anonymized_telemetry=False),
            )
        else:
            self.client = chromadb.Client(
                settings=Settings(anonymized_telemetry=False),
            )
        # Set up embedding function
        self.embedding_function = embedding_function or get_embedding_function()
        # Get or create collection (shared with clear()).
        self.collection = self._get_or_create_collection()

    def _get_or_create_collection(self):
        """Get or create the knowledge-base collection with cosine space."""
        return self.client.get_or_create_collection(
            name=self.COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )

    def add_chunks(self, chunks: list[Chunk]) -> int:
        """Add chunks to the vector store.

        Args:
            chunks: List of Chunk objects to add

        Returns:
            Number of chunks added
        """
        if not chunks:
            return 0
        ids = [chunk.id for chunk in chunks]
        documents = [chunk.text for chunk in chunks]
        metadatas = [chunk.to_metadata() for chunk in chunks]
        # Generate embeddings explicitly so the collection never needs
        # its own embedding function.
        embeddings = self.embedding_function(documents)
        self.collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
        )
        return len(chunks)

    def query(
        self,
        query_text: str,
        n_results: int = 5,
        where: Optional[dict] = None,
        where_document: Optional[dict] = None,
    ) -> list[dict]:
        """Query the vector store.

        Args:
            query_text: Query text to search for
            n_results: Number of results to return
            where: Metadata filter (e.g., {"priority": "primary"})
            where_document: Document content filter

        Returns:
            List of result dicts with keys: id, document, metadata, distance
        """
        # Embed the query with the same function used for documents.
        query_embedding = self.embedding_function([query_text])[0]
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where,
            where_document=where_document,
            include=["documents", "metadatas", "distances"],
        )
        # Flatten Chroma's nested (per-query) result lists; we issued a
        # single query so only index 0 is populated.
        formatted = []
        if results["ids"] and results["ids"][0]:
            for i, chunk_id in enumerate(results["ids"][0]):
                formatted.append(
                    {
                        "id": chunk_id,
                        "document": results["documents"][0][i],
                        "metadata": results["metadatas"][0][i],
                        "distance": results["distances"][0][i],
                    }
                )
        return formatted

    def get_stats(self) -> dict:
        """Get collection statistics.

        Returns:
            Dict with total chunk count, category and priority
            distributions, collection name, and persist directory.
        """
        count = self.collection.count()
        categories: Counter = Counter()
        priorities: Counter = Counter()
        if count > 0:
            # Fetch all metadatas to tally distributions; acceptable for
            # knowledge bases of this scale.
            all_results = self.collection.get(include=["metadatas"])
            for metadata in all_results["metadatas"]:
                categories[metadata.get("category", "unknown")] += 1
                priorities[metadata.get("priority", "unknown")] += 1
        return {
            "total_chunks": count,
            "categories": dict(categories),
            "priorities": dict(priorities),
            "collection_name": self.COLLECTION_NAME,
            "persist_directory": self.persist_directory,
        }

    def clear(self):
        """Clear all data from the collection."""
        self.client.delete_collection(self.COLLECTION_NAME)
        # Recreate an empty collection with the same configuration.
        self.collection = self._get_or_create_collection()

    def delete_by_source(self, source: str) -> int:
        """Delete all chunks from a specific source.

        Args:
            source: Source filename to delete

        Returns:
            Number of chunks deleted
        """
        # include=[] still returns ids, which is all we need here.
        results = self.collection.get(
            where={"source": source},
            include=[],
        )
        ids = results["ids"]
        if not ids:
            return 0
        self.collection.delete(ids=ids)
        return len(ids)