# SmokeScan / rag/vectorstore.py
# Author: KinetoLabs
# Commit d1901ae: "Fix multi-GPU compatibility issues (6 locations)"
# File size: 10.8 kB
"""ChromaDB vector store for FDAM knowledge base.
Provides embedding and storage with metadata support.
Uses mock embeddings when MOCK_MODELS=true for local development.
"""
import hashlib
import logging
from typing import Optional
from pathlib import Path
import chromadb
from chromadb.config import Settings
from config.settings import settings
from .chunker import Chunk
# Module-level logger named after this module (``__name__``).
logger = logging.getLogger(__name__)
class MockEmbeddingFunction:
    """Deterministic stand-in for the real embedding model (local development).

    Each text is hashed with SHA-256. The 32 digest bytes are rescaled into
    [-1, 1] and repeated cyclically until ``EMBEDDING_DIM`` values exist. The
    vector is then L2-normalized, so it matches the real model's output width
    (4096, Qwen3-VL-Embedding-8B) and unit length.
    """

    EMBEDDING_DIM = 4096  # Per Qwen3-VL-Embedding-8B hidden_size

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Return one pseudo-embedding per text in ``input``, preserving order."""
        return list(map(self._embed_text, input))

    def _embed_text(self, text: str) -> list[float]:
        """Map ``text`` to a unit-length pseudo-embedding derived from its SHA-256 digest."""
        import math

        digest = hashlib.sha256(text.encode("utf-8")).digest()
        # Rescale every digest byte (0..255) into the closed range [-1.0, 1.0].
        scaled = [(b / 127.5) - 1.0 for b in digest]
        width = len(scaled)
        # Tile the scaled bytes cyclically to fill all EMBEDDING_DIM slots.
        vector = [scaled[k % width] for k in range(self.EMBEDDING_DIM)]

        # L2-normalize, as the real model's output is.
        magnitude = math.sqrt(sum(v * v for v in vector))
        if magnitude > 0:
            return [v / magnitude for v in vector]
        return vector
class RealEmbeddingFunction:
    """Real embedding function using Qwen3-VL-Embedding-8B.

    Uses last-token pooling per the official Qwen3-VL-Embedding implementation,
    followed by L2 normalization. The tokenizer and model are loaded lazily on
    the first non-empty call (used when MOCK_MODELS=false).

    Reference: https://github.com/QwenLM/Qwen3-VL-Embedding
    """

    EMBEDDING_DIM = 4096  # Per Qwen3-VL-Embedding-8B hidden_size
    MODEL_NAME = "Qwen/Qwen3-VL-Embedding-8B"
    MAX_LENGTH = 512  # Token limit per text; longer inputs are truncated.

    def __init__(self):
        # Populated by _load_model() on first use.
        self.model = None
        self.tokenizer = None

    def _load_model(self):
        """Load the tokenizer and model once; subsequent calls are no-ops."""
        if self.model is not None:
            return
        import torch
        from transformers import AutoModel, AutoTokenizer

        model_name = self.MODEL_NAME
        logger.info(f"Loading embedding model: {model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        # device_map="auto" lets transformers/accelerate place (and possibly
        # shard) the model across the available devices.
        self.model = AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()

    @staticmethod
    def _pooling_last(hidden_state, attention_mask):
        """Extract each row's hidden state at its last valid (mask == 1) token.

        Official pooling method from Qwen3-VL-Embedding. Works for both left-
        and right-padded batches.

        Args:
            hidden_state: Tensor indexed as (batch, seq, hidden).
            attention_mask: Tensor indexed as (batch, seq) with 1 for real tokens.

        Returns:
            Tensor of shape (batch, hidden).
        """
        import torch

        # With device_map="auto" the last hidden state can live on a different
        # device than the tokenizer's CPU mask; align them so every index tensor
        # below is on the same device as the tensor being indexed.
        attention_mask = attention_mask.to(hidden_state.device)
        # argmax on the flipped mask finds the first 1 from the right, i.e.
        # the distance of the last valid token from the end of the sequence.
        flipped_tensor = attention_mask.flip(dims=[1])
        last_one_positions = flipped_tensor.argmax(dim=1)
        col = attention_mask.shape[1] - last_one_positions - 1
        row = torch.arange(hidden_state.shape[0], device=hidden_state.device)
        return hidden_state[row, col]

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Generate L2-normalized embeddings for a list of texts.

        Texts are embedded one at a time with last-token pooling.

        Args:
            input: Texts to embed.

        Returns:
            One list of floats per input text, in the same order. An empty
            input returns ``[]`` without loading the model.
        """
        if not input:
            return []
        self._load_model()
        import torch

        embeddings = []
        with torch.no_grad():
            for text in input:
                inputs = self.tokenizer(
                    text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=self.MAX_LENGTH,
                    padding=True,
                )
                # Note: With device_map="auto", transformers handles device routing internally
                # Do NOT call .to(device) - it breaks distributed models
                outputs = self.model(**inputs)
                hidden = outputs.last_hidden_state

                # Use last-token pooling (official Qwen3-VL-Embedding method)
                attention_mask = inputs.get("attention_mask")
                if attention_mask is not None:
                    embedding = self._pooling_last(hidden, attention_mask)
                else:
                    # Fallback: use last token if no attention mask
                    embedding = hidden[:, -1, :]

                # L2 normalize (per official implementation)
                embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)
                # Batch of exactly one text: take row 0 explicitly rather than
                # relying on squeeze().
                embeddings.append(embedding[0].cpu().float().tolist())
        return embeddings
def get_embedding_function():
    """Return a new embedding backend chosen by ``settings.mock_models``.

    Returns:
        A ``MockEmbeddingFunction`` when mock models are enabled, otherwise a
        ``RealEmbeddingFunction``.
    """
    backend = MockEmbeddingFunction if settings.mock_models else RealEmbeddingFunction
    return backend()
class ChromaVectorStore:
    """ChromaDB-based vector store for the FDAM knowledge base.

    Text is embedded with ``self.embedding_function`` and the resulting vectors
    are passed to ChromaDB explicitly on both insert and query. The collection
    is configured for cosine distance (``"hnsw:space": "cosine"``).
    """

    COLLECTION_NAME = "fdam_knowledge_base"
    # Default number of chunks embedded and written per ``collection.add`` call.
    # ChromaDB rejects batches above the client's max batch size, so large
    # ingests are split; this value is chosen to stay below typical limits.
    ADD_BATCH_SIZE = 1000

    def __init__(
        self,
        persist_directory: Optional[str] = None,
        embedding_function=None,
    ):
        """Initialize vector store.

        Args:
            persist_directory: Directory for ChromaDB persistence (created if
                missing). If None (or empty), uses in-memory storage.
            embedding_function: Callable mapping ``list[str]`` to a list of
                vectors. If None, uses ``get_embedding_function()``.
        """
        self.persist_directory = persist_directory

        # Initialize ChromaDB client
        if persist_directory:
            persist_path = Path(persist_directory)
            persist_path.mkdir(parents=True, exist_ok=True)
            logger.debug(f"ChromaDB: using persistent storage at {persist_path}")
            self.client = chromadb.PersistentClient(
                path=str(persist_path),
                settings=Settings(anonymized_telemetry=False),
            )
        else:
            logger.debug("ChromaDB: using in-memory storage")
            self.client = chromadb.Client(
                settings=Settings(anonymized_telemetry=False),
            )

        # Set up embedding function. A caller-supplied function is reported by
        # its type name; the mock/real label only describes the default choice.
        if embedding_function is not None:
            self.embedding_function = embedding_function
            embed_type = type(embedding_function).__name__
        else:
            self.embedding_function = get_embedding_function()
            embed_type = "mock" if settings.mock_models else "real"
        logger.debug(f"ChromaDB: using {embed_type} embeddings")

        self.collection = self._get_or_create_collection()
        logger.info(f"ChromaDB collection '{self.COLLECTION_NAME}' ready: {self.collection.count()} chunks")

    def _get_or_create_collection(self):
        """Open (creating if needed) the knowledge-base collection with cosine distance."""
        return self.client.get_or_create_collection(
            name=self.COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )

    def add_chunks(self, chunks: list[Chunk], *, batch_size: Optional[int] = None) -> int:
        """Embed and add chunks to the vector store.

        Chunks are processed in batches so a single ``collection.add`` call
        never exceeds ChromaDB's batch-size limit. If a later batch fails,
        earlier batches have already been written.

        Args:
            chunks: List of Chunk objects to add.
            batch_size: Chunks per batch; defaults to ``ADD_BATCH_SIZE``.

        Returns:
            Number of chunks added.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if not chunks:
            return 0
        step = self.ADD_BATCH_SIZE if batch_size is None else batch_size
        if step < 1:
            raise ValueError(f"batch_size must be >= 1, got {step}")

        for start in range(0, len(chunks), step):
            batch = chunks[start:start + step]
            ids = [chunk.id for chunk in batch]
            documents = [chunk.text for chunk in batch]
            metadatas = [chunk.to_metadata() for chunk in batch]
            # Embeddings are computed here and passed explicitly, so the
            # collection never embeds text itself.
            embeddings = self.embedding_function(documents)
            self.collection.add(
                ids=ids,
                embeddings=embeddings,
                documents=documents,
                metadatas=metadatas,
            )
        return len(chunks)

    def query(
        self,
        query_text: str,
        n_results: int = 5,
        where: Optional[dict] = None,
        where_document: Optional[dict] = None,
    ) -> list[dict]:
        """Query the vector store.

        Args:
            query_text: Query text to search for.
            n_results: Number of results to return. Values below 1 return
                ``[]`` without computing an embedding.
            where: Metadata filter (e.g., {"priority": "primary"}).
            where_document: Document content filter.

        Returns:
            List of result dicts with keys: id, document, metadata, distance
            (cosine distance; smaller is closer).
        """
        if n_results < 1:
            return []

        query_embedding = self.embedding_function([query_text])[0]
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where,
            where_document=where_document,
            include=["documents", "metadatas", "distances"],
        )

        # Results are nested per query embedding; only one was sent, so use [0].
        if not results["ids"] or not results["ids"][0]:
            return []
        return [
            {
                "id": chunk_id,
                "document": document,
                "metadata": metadata,
                "distance": distance,
            }
            for chunk_id, document, metadata, distance in zip(
                results["ids"][0],
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0],
            )
        ]

    def get_stats(self) -> dict:
        """Return collection statistics.

        Returns:
            Dict with ``total_chunks``, chunk counts per ``category`` and per
            ``priority`` metadata value (missing values counted as
            ``"unknown"``), ``collection_name`` and ``persist_directory``.
        """
        count = self.collection.count()
        categories: dict = {}
        priorities: dict = {}
        if count > 0:
            # Fetches the metadata of every chunk (not a sample); cost grows
            # with collection size.
            all_results = self.collection.get(include=["metadatas"])
            for metadata in all_results["metadatas"] or []:
                # Guard against records without metadata (entry may be None).
                metadata = metadata or {}
                cat = metadata.get("category", "unknown")
                pri = metadata.get("priority", "unknown")
                categories[cat] = categories.get(cat, 0) + 1
                priorities[pri] = priorities.get(pri, 0) + 1

        return {
            "total_chunks": count,
            "categories": categories,
            "priorities": priorities,
            "collection_name": self.COLLECTION_NAME,
            "persist_directory": self.persist_directory,
        }

    def clear(self):
        """Clear all data by dropping the collection and recreating it empty."""
        self.client.delete_collection(self.COLLECTION_NAME)
        self.collection = self._get_or_create_collection()

    def delete_by_source(self, source: str) -> int:
        """Delete all chunks whose ``source`` metadata equals ``source``.

        Args:
            source: Source filename to delete.

        Returns:
            Number of chunks deleted.
        """
        # Only IDs are needed, so request no documents/metadata/embeddings.
        results = self.collection.get(
            where={"source": source},
            include=[],
        )
        ids = results["ids"]
        if not ids:
            return 0
        self.collection.delete(ids=ids)
        return len(ids)