"""ChromaDB vector store for FDAM knowledge base.
Provides embedding and storage with metadata support.
Uses mock embeddings when MOCK_MODELS=true for local development.
"""
import hashlib
import logging
from typing import Optional
from pathlib import Path
import chromadb
from chromadb.config import Settings
from config.settings import settings
from .chunker import Chunk
logger = logging.getLogger(__name__)
class MockEmbeddingFunction:
    """Mock embedding function for local development.

    Generates deterministic pseudo-embeddings derived from a SHA-256
    hash of the text. Produces 2048-dimensional, L2-normalized vectors
    (matches Qwen3-VL-Embedding-2B).
    """

    # Per Qwen3-VL-Embedding-2B hidden_size.
    EMBEDDING_DIM = 2048

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Embed each text in *input*, one vector per text."""
        return list(map(self._embed_text, input))

    def _embed_text(self, text: str) -> list[float]:
        """Derive a deterministic pseudo-embedding from *text*.

        The SHA-256 digest of the text is tiled cyclically across the
        embedding dimensions, each byte scaled to [-1, 1], and the
        resulting vector is L2-normalized to match real model output.
        """
        import math

        digest = hashlib.sha256(text.encode("utf-8")).digest()
        size = len(digest)
        # Tile the 32 digest bytes across all dimensions, mapped to [-1, 1].
        vector = [
            (digest[i % size] / 127.5) - 1.0
            for i in range(self.EMBEDDING_DIM)
        ]
        # L2 normalize (matching real model behavior); guard zero norm.
        magnitude = math.sqrt(sum(v * v for v in vector))
        if magnitude > 0:
            vector = [v / magnitude for v in vector]
        return vector
class SharedEmbeddingFunction:
    """ChromaDB-compatible embedding function backed by the shared model stack.

    Reuses the embedding model already loaded by the pipeline at startup
    (via RealModelStack), so no duplicate embedding model is loaded just
    for the vector store.
    """

    # Per Qwen3-VL-Embedding-2B hidden_size.
    EMBEDDING_DIM = 2048

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Embed *input* texts with the model stack's embedding model."""
        # Local import keeps the module free of a hard model-stack dependency.
        from models.loader import get_models

        stack = get_models()
        # The shared embedding model is always loaded at startup.
        return stack.embedding.embed_batch(input)
def get_embedding_function():
    """Return the embedding function appropriate for the current settings.

    Mock mode yields deterministic hash-based embeddings; otherwise a
    SharedEmbeddingFunction wraps the model stack's already-loaded
    embedding model (no duplicate model loading).
    """
    return (
        MockEmbeddingFunction()
        if settings.mock_models
        else SharedEmbeddingFunction()
    )
class ChromaVectorStore:
    """ChromaDB-based vector store for FDAM knowledge base."""

    COLLECTION_NAME = "fdam_knowledge_base"

    def __init__(
        self,
        persist_directory: Optional[str] = None,
        embedding_function=None,
    ):
        """Initialize vector store.

        Args:
            persist_directory: Directory for ChromaDB persistence.
                If None, uses in-memory storage.
            embedding_function: Custom embedding function.
                If None, uses appropriate default from settings.
        """
        self.persist_directory = persist_directory
        # Initialize ChromaDB client: persistent when a directory is given.
        if persist_directory:
            persist_path = Path(persist_directory)
            persist_path.mkdir(parents=True, exist_ok=True)
            logger.debug("ChromaDB: using persistent storage at %s", persist_path)
            self.client = chromadb.PersistentClient(
                path=str(persist_path),
                settings=Settings(anonymized_telemetry=False),
            )
        else:
            logger.debug("ChromaDB: using in-memory storage")
            self.client = chromadb.Client(
                settings=Settings(anonymized_telemetry=False),
            )
        # Set up embedding function. Explicit `is None` check: a provided
        # callable must be honored even if it happens to be falsy.
        if embedding_function is None:
            embedding_function = get_embedding_function()
        self.embedding_function = embedding_function
        embed_type = "mock" if settings.mock_models else "real"
        logger.debug("ChromaDB: using %s embeddings", embed_type)
        # Get or create collection (same path as clear()).
        self.collection = self._create_collection()
        logger.info(
            "ChromaDB collection '%s' ready: %s chunks",
            self.COLLECTION_NAME,
            self.collection.count(),
        )

    def _create_collection(self):
        """Get or create the cosine-distance collection on self.client."""
        return self.client.get_or_create_collection(
            name=self.COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )

    def add_chunks(self, chunks: "list[Chunk]") -> int:
        """Add chunks to the vector store.

        Args:
            chunks: List of Chunk objects to add

        Returns:
            Number of chunks added
        """
        if not chunks:
            return 0
        ids = [chunk.id for chunk in chunks]
        documents = [chunk.text for chunk in chunks]
        metadatas = [chunk.to_metadata() for chunk in chunks]
        # Generate embeddings up front so the collection stores our vectors
        # rather than computing its own.
        embeddings = self.embedding_function(documents)
        # Add to collection
        self.collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
        )
        return len(chunks)

    def query(
        self,
        query_text: str,
        n_results: int = 5,
        where: Optional[dict] = None,
        where_document: Optional[dict] = None,
    ) -> list[dict]:
        """Query the vector store.

        Args:
            query_text: Query text to search for
            n_results: Number of results to return
            where: Metadata filter (e.g., {"priority": "primary"})
            where_document: Document content filter

        Returns:
            List of result dicts with keys: id, document, metadata, distance
        """
        # Embed the query with the same function used at ingest time.
        query_embedding = self.embedding_function([query_text])[0]
        # Query collection
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where,
            where_document=where_document,
            include=["documents", "metadatas", "distances"],
        )
        # Flatten Chroma's per-query nested lists (we issue a single query,
        # so index [0] everywhere) into one dict per hit.
        formatted = []
        if results["ids"] and results["ids"][0]:
            for i, chunk_id in enumerate(results["ids"][0]):
                formatted.append(
                    {
                        "id": chunk_id,
                        "document": results["documents"][0][i],
                        "metadata": results["metadatas"][0][i],
                        "distance": results["distances"][0][i],
                    }
                )
        return formatted

    def get_stats(self) -> dict:
        """Get collection statistics.

        Returns:
            Dict with total chunk count, per-category and per-priority
            counts, collection name, and persistence directory.
        """
        count = self.collection.count()
        # Get category/priority distribution. NOTE: this fetches metadata
        # for every chunk (full scan) — fine at current KB scale.
        categories = {}
        priorities = {}
        if count > 0:
            all_results = self.collection.get(include=["metadatas"])
            # `or []` guards against a missing/None metadatas field.
            for metadata in all_results["metadatas"] or []:
                cat = metadata.get("category", "unknown")
                pri = metadata.get("priority", "unknown")
                categories[cat] = categories.get(cat, 0) + 1
                priorities[pri] = priorities.get(pri, 0) + 1
        return {
            "total_chunks": count,
            "categories": categories,
            "priorities": priorities,
            "collection_name": self.COLLECTION_NAME,
            "persist_directory": self.persist_directory,
        }

    def clear(self):
        """Clear all data by dropping and recreating the collection."""
        self.client.delete_collection(self.COLLECTION_NAME)
        self.collection = self._create_collection()

    def delete_by_source(self, source: str) -> int:
        """Delete all chunks from a specific source.

        Args:
            source: Source filename to delete

        Returns:
            Number of chunks deleted
        """
        # Fetch only the IDs of chunks from this source (include=[] skips
        # documents/metadata payloads).
        results = self.collection.get(
            where={"source": source},
            include=[],
        )
        if results["ids"]:
            self.collection.delete(ids=results["ids"])
            return len(results["ids"])
        return 0