File size: 8,581 Bytes
88bdcff
 
 
 
 
 
 
f3ebc82
88bdcff
 
 
 
 
 
 
 
 
f3ebc82
 
88bdcff
 
 
 
 
706520f
88bdcff
 
706520f
88bdcff
 
 
 
 
 
 
 
 
f3ebc82
88bdcff
f3ebc82
 
88bdcff
 
 
 
 
 
 
 
 
 
 
f3ebc82
 
 
 
 
88bdcff
 
 
5f0db1e
 
88bdcff
5f0db1e
333c083
f3ebc82
5f0db1e
88bdcff
 
706520f
88bdcff
 
5f0db1e
 
88bdcff
5f0db1e
88bdcff
333c083
5f0db1e
88bdcff
 
 
5f0db1e
 
 
 
 
88bdcff
 
5f0db1e
88bdcff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3ebc82
88bdcff
 
 
 
 
f3ebc82
88bdcff
 
 
 
 
 
f3ebc82
 
88bdcff
 
 
 
 
 
f3ebc82
88bdcff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
"""ChromaDB vector store for FDAM knowledge base.

Provides embedding and storage with metadata support.
Uses mock embeddings when MOCK_MODELS=true for local development.
"""

import hashlib
import logging
import math
from pathlib import Path
from typing import Optional

import chromadb
from chromadb.config import Settings

from config.settings import settings
from .chunker import Chunk

logger = logging.getLogger(__name__)


class MockEmbeddingFunction:
    """Mock embedding function for local development.

    Generates deterministic pseudo-embeddings based on text hash.
    Produces 2048-dimensional vectors (matches Qwen3-VL-Embedding-2B).
    """

    EMBEDDING_DIM = 2048  # Per Qwen3-VL-Embedding-2B hidden_size

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Generate mock embeddings for a list of texts."""
        return [self._embed_text(text) for text in input]

    def _embed_text(self, text: str) -> list[float]:
        """Generate a deterministic pseudo-embedding from text.

        Uses a SHA-256 hash tiled to fill the embedding dimensions,
        then L2-normalized to match real model output.

        Args:
            text: Input text to pseudo-embed.

        Returns:
            A list of EMBEDDING_DIM floats with unit L2 norm.
        """
        # Hash the text (deterministic: same text -> same embedding)
        text_hash = hashlib.sha256(text.encode("utf-8")).digest()

        # Map each of the 32 hash bytes to [-1, 1] once, then tile the
        # pattern to EMBEDDING_DIM. Equivalent to indexing the hash
        # cyclically per dimension, but avoids 2048 modulo/index ops.
        base = [(byte_val / 127.5) - 1.0 for byte_val in text_hash]
        reps = -(-self.EMBEDDING_DIM // len(base))  # ceil division
        embedding = (base * reps)[: self.EMBEDDING_DIM]

        # L2 normalize (matching real model behavior)
        norm = math.sqrt(sum(x * x for x in embedding))
        if norm > 0:
            embedding = [x / norm for x in embedding]

        return embedding


class SharedEmbeddingFunction:
    """Embedding function backed by the pipeline's already-loaded model.

    Rather than instantiating a second copy of the embedding model, this
    delegates to the one RealModelStack loads at startup, wrapped in the
    callable interface ChromaDB expects.
    """

    EMBEDDING_DIM = 2048  # Per Qwen3-VL-Embedding-2B hidden_size

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Embed a batch of texts via the shared model stack."""
        # Local import keeps this module importable without the model stack.
        from models.loader import get_models

        # The embedding model is loaded at startup, so no lazy-load check
        # is needed here.
        return get_models().embedding.embed_batch(input)


def get_embedding_function():
    """Return the embedding function appropriate for current settings.

    Mock mode yields deterministic hash-based embeddings; otherwise the
    shared model-stack embedder is returned, avoiding a duplicate model
    load.
    """
    use_mock = settings.mock_models
    return MockEmbeddingFunction() if use_mock else SharedEmbeddingFunction()


class ChromaVectorStore:
    """ChromaDB-based vector store for FDAM knowledge base."""

    COLLECTION_NAME = "fdam_knowledge_base"

    def __init__(
        self,
        persist_directory: Optional[str] = None,
        embedding_function=None,
    ):
        """Set up the ChromaDB client, embedder, and collection.

        Args:
            persist_directory: Directory for ChromaDB persistence.
                             If None, uses in-memory storage.
            embedding_function: Custom embedding function.
                              If None, uses appropriate default.
        """
        self.persist_directory = persist_directory

        # Ephemeral client unless a persistence directory was supplied.
        if not persist_directory:
            logger.debug("ChromaDB: using in-memory storage")
            self.client = chromadb.Client(
                settings=Settings(anonymized_telemetry=False),
            )
        else:
            persist_path = Path(persist_directory)
            persist_path.mkdir(parents=True, exist_ok=True)
            logger.debug(f"ChromaDB: using persistent storage at {persist_path}")
            self.client = chromadb.PersistentClient(
                path=str(persist_path),
                settings=Settings(anonymized_telemetry=False),
            )

        # Caller-supplied embedder wins; otherwise pick per settings.
        self.embedding_function = embedding_function or get_embedding_function()
        embed_type = "mock" if settings.mock_models else "real"
        logger.debug(f"ChromaDB: using {embed_type} embeddings")

        # Reuse the collection if it already exists (cosine distance).
        self.collection = self.client.get_or_create_collection(
            name=self.COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )
        logger.info(f"ChromaDB collection '{self.COLLECTION_NAME}' ready: {self.collection.count()} chunks")

    def add_chunks(self, chunks: list[Chunk]) -> int:
        """Embed and store a batch of chunks.

        Args:
            chunks: List of Chunk objects to add

        Returns:
            Number of chunks added
        """
        if not chunks:
            return 0

        documents = [chunk.text for chunk in chunks]

        # Embed up front, then write everything in a single add() call.
        self.collection.add(
            ids=[chunk.id for chunk in chunks],
            embeddings=self.embedding_function(documents),
            documents=documents,
            metadatas=[chunk.to_metadata() for chunk in chunks],
        )

        return len(chunks)

    def query(
        self,
        query_text: str,
        n_results: int = 5,
        where: Optional[dict] = None,
        where_document: Optional[dict] = None,
    ) -> list[dict]:
        """Run a similarity search against the collection.

        Args:
            query_text: Query text to search for
            n_results: Number of results to return
            where: Metadata filter (e.g., {"priority": "primary"})
            where_document: Document content filter

        Returns:
            List of result dicts with keys: id, document, metadata, distance
        """
        # Single-query embedding
        embedded = self.embedding_function([query_text])[0]

        raw = self.collection.query(
            query_embeddings=[embedded],
            n_results=n_results,
            where=where,
            where_document=where_document,
            include=["documents", "metadatas", "distances"],
        )

        # ChromaDB returns parallel per-query lists; flatten the first
        # (only) query's rows into one dict per hit.
        hits: list[dict] = []
        if raw["ids"] and raw["ids"][0]:
            rows = zip(
                raw["ids"][0],
                raw["documents"][0],
                raw["metadatas"][0],
                raw["distances"][0],
            )
            for chunk_id, doc, meta, dist in rows:
                hits.append(
                    {
                        "id": chunk_id,
                        "document": doc,
                        "metadata": meta,
                        "distance": dist,
                    }
                )

        return hits

    def get_stats(self) -> dict:
        """Get collection statistics."""
        total = self.collection.count()

        # Tally metadata distribution across all stored chunks.
        categories: dict = {}
        priorities: dict = {}

        if total > 0:
            # Fetches every chunk's metadata; fine at KB scale.
            stored = self.collection.get(include=["metadatas"])
            for meta in stored["metadatas"]:
                cat = meta.get("category", "unknown")
                categories[cat] = categories.get(cat, 0) + 1
                pri = meta.get("priority", "unknown")
                priorities[pri] = priorities.get(pri, 0) + 1

        return {
            "total_chunks": total,
            "categories": categories,
            "priorities": priorities,
            "collection_name": self.COLLECTION_NAME,
            "persist_directory": self.persist_directory,
        }

    def clear(self):
        """Drop and recreate the collection, removing all stored chunks."""
        self.client.delete_collection(self.COLLECTION_NAME)
        self.collection = self.client.get_or_create_collection(
            name=self.COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )

    def delete_by_source(self, source: str) -> int:
        """Delete all chunks from a specific source.

        Args:
            source: Source filename to delete

        Returns:
            Number of chunks deleted
        """
        # Fetch matching IDs only (include=[] skips documents/metadata).
        matched = self.collection.get(
            where={"source": source},
            include=[],
        )

        doomed = matched["ids"]
        if not doomed:
            return 0

        self.collection.delete(ids=doomed)
        return len(doomed)