File size: 9,902 Bytes
641b53a
 
 
 
 
 
 
 
0d03152
 
9cfa1a8
641b53a
 
 
d5149c9
9cfa1a8
 
 
 
641b53a
 
 
 
 
 
 
10ea2c4
 
 
 
 
 
d5149c9
 
10ea2c4
641b53a
 
 
 
 
9cfa1a8
0d03152
d5149c9
 
 
 
 
 
 
 
641b53a
 
 
 
 
 
 
 
 
 
 
9cfa1a8
641b53a
9cfa1a8
0d89edc
641b53a
 
 
 
 
 
0d03152
 
 
 
 
 
 
 
d5149c9
 
b1a23d2
d5149c9
 
 
b1a23d2
641b53a
d5149c9
 
 
0d03152
 
 
 
d5149c9
0d03152
 
 
 
 
 
 
 
 
641b53a
 
 
 
 
d5149c9
 
 
 
 
 
 
 
 
 
 
 
 
641b53a
 
d5149c9
641b53a
d5149c9
 
641b53a
9cfa1a8
d5149c9
641b53a
d5149c9
 
641b53a
d5149c9
641b53a
 
 
 
 
 
 
 
d5149c9
641b53a
 
 
 
d5149c9
 
641b53a
 
d5149c9
0315b16
 
 
 
 
641b53a
 
0315b16
 
641b53a
 
9cfa1a8
641b53a
 
 
 
 
 
 
 
d5149c9
641b53a
d5149c9
641b53a
 
 
 
 
 
 
 
 
 
 
d5149c9
641b53a
d5149c9
641b53a
 
d5149c9
641b53a
 
0315b16
641b53a
 
 
0315b16
 
 
 
641b53a
 
0315b16
641b53a
 
9cfa1a8
 
641b53a
0315b16
641b53a
 
 
0315b16
641b53a
0315b16
d5149c9
0315b16
641b53a
 
 
 
 
0315b16
641b53a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
'''
Vector database layer for document embeddings and similarity search.

- ChromaDB for development
- Qdrant for production scaling
- Embedding generation via Ollama or sentence-transformers
- Batch operations for efficiency
- Metadata filtering capabilities
'''
import logging
import os
import time
from typing import Any, Dict, List, Optional, Sequence

import chromadb
import ollama
from src.utils.config import EmbeddingProfile

# qdrant_client is imported lazily (inside the `qdrant_client` property and the
# Qdrant branches of the VectorDatabase methods) so the module can be imported
# on environments where qdrant-client is not installed (e.g. HF Spaces running
# in Chroma-only / dev mode).

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Number of documents embedded and upserted per batch by VectorDatabase.add_documents.
BATCH_SIZE = 100


class VectorDatabase:
    """Vector store facade backed by ChromaDB or Qdrant.

    ``mode == "dev"`` routes every operation to a persistent ChromaDB client;
    any other mode value routes to Qdrant. Both clients are created lazily on
    first use. Embeddings are produced either by a local SentenceTransformer
    model or by an Ollama server, depending on ``embedding_profile.provider``.
    """

    # Total number of Ollama embedding attempts before the last error propagates.
    _OLLAMA_ATTEMPTS = 3
    # Linear backoff unit in seconds: the n-th failed attempt sleeps n * this value.
    _OLLAMA_BACKOFF_SECONDS = 0.35

    def __init__(
        self,
        mode: str = "dev",
        qdrant_host: str = "localhost",
        qdrant_port: int = 6333,
        chroma_path: str = "./chroma_db",
        embedding_profile_name: str = "ollama_nomic",
        embedding_profile: Optional[EmbeddingProfile] = None,
    ):
        """Store connection settings; no database connection is opened here.

        Args:
            mode: ``"dev"`` selects ChromaDB; any other value selects Qdrant.
            qdrant_host: Host passed to ``QdrantClient`` when it is first needed.
            qdrant_port: Port passed to ``QdrantClient`` when it is first needed.
            chroma_path: Directory passed to ``chromadb.PersistentClient``.
            embedding_profile_name: Profile label used to derive collection
                names and in dimension-mismatch error messages.
            embedding_profile: Embedding settings. When omitted, a profile is
                built with provider/framework taken from the
                ``DOC_EMBEDDING_PROVIDER`` environment variable (default
                ``"ollama"``), model ``"nomic-embed-text"`` and dimension 768.
        """
        self.mode = mode
        self._chroma_path = chroma_path
        self._qdrant_host = qdrant_host
        self._qdrant_port = qdrant_port
        self._chroma_client: Optional[chromadb.ClientAPI] = None
        self._qdrant_client: Optional[Any] = None  # QdrantClient, imported lazily
        self._ollama_client = ollama.Client(host=self._resolve_ollama_host())
        # SentenceTransformer instances keyed by model name, loaded on demand.
        self._st_model_cache: Dict[str, Any] = {}
        self.embedding_profile_name = embedding_profile_name
        if embedding_profile is None:
            # Read the environment once; provider and framework share the value.
            default_provider = os.getenv("DOC_EMBEDDING_PROVIDER", "ollama").lower()
            embedding_profile = EmbeddingProfile(
                provider=default_provider,
                framework=default_provider,
                model="nomic-embed-text",
                dimension=768,
            )
        self.embedding_profile = embedding_profile

    # --- client accessors (lazy init) ---

    @property
    def chroma_client(self) -> chromadb.ClientAPI:
        """Return the ChromaDB persistent client, creating it on first access."""
        if self._chroma_client is None:
            self._chroma_client = chromadb.PersistentClient(path=self._chroma_path)
            logger.info("ChromaDB initialized at %s", self._chroma_path)
        return self._chroma_client

    @property
    def qdrant_client(self) -> Any:
        """Return the Qdrant client, importing and creating it on first access."""
        if self._qdrant_client is None:
            from qdrant_client import QdrantClient  # noqa: PLC0415

            self._qdrant_client = QdrantClient(host=self._qdrant_host, port=self._qdrant_port)
            logger.info("Qdrant initialized at %s:%s", self._qdrant_host, self._qdrant_port)
        return self._qdrant_client

    # --- embedding ---

    @staticmethod
    def _resolve_ollama_host() -> str:
        """Return the Ollama URL: OLLAMA_BASE_URL, else OLLAMA_HOST, else localhost:11434."""
        return (
            os.getenv("OLLAMA_BASE_URL")
            or os.getenv("OLLAMA_HOST")
            or "http://localhost:11434"
        )

    def _generate_st_embedding(self, text: str, model_name: str) -> List[float]:
        """Embed ``text`` with a SentenceTransformer model, caching the loaded model."""
        if model_name not in self._st_model_cache:
            from sentence_transformers import SentenceTransformer  # noqa: PLC0415
            self._st_model_cache[model_name] = SentenceTransformer(model_name)
            logger.info("SentenceTransformer loaded: %s", model_name)
        return self._st_model_cache[model_name].encode(text, show_progress_bar=False).tolist()

    def generate_embedding(self, text: str) -> List[float]:
        """Return the embedding vector for ``text`` using the active profile.

        The ``sentence_transformers`` provider is served locally; every other
        provider value goes to the Ollama client. Ollama calls are retried up
        to ``_OLLAMA_ATTEMPTS`` times with a linearly growing sleep in between.

        Raises:
            Exception: Whatever the final Ollama attempt raised, unchanged.
        """
        provider = self.embedding_profile.provider.strip().lower()
        if provider == "sentence_transformers":
            return self._generate_st_embedding(text, self.embedding_profile.model)

        for attempt in range(1, self._OLLAMA_ATTEMPTS + 1):
            try:
                response = self._ollama_client.embeddings(model=self.embedding_profile.model, prompt=text)
                return response["embedding"]  # type: ignore[return-value]
            except Exception as exc:
                # Retry boundary: the final failure propagates untouched so
                # callers still see the original exception type.
                if attempt == self._OLLAMA_ATTEMPTS:
                    raise
                logger.warning(
                    "Ollama embedding attempt %d/%d failed: %s",
                    attempt,
                    self._OLLAMA_ATTEMPTS,
                    exc,
                )
                time.sleep(self._OLLAMA_BACKOFF_SECONDS * attempt)
        # Only reachable if the attempt count is configured to zero.
        raise RuntimeError("Unexpected Ollama embedding retry state")

    def generate_embeddings_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed each text in order, one ``generate_embedding`` call per item."""
        return [self.generate_embedding(t) for t in texts]

    # --- collection management ---
    def collection_name_for_profile(self, base_collection_name: str) -> str:
        """Map a base collection name to the profile-specific collection name.

        The base name is kept as-is only for the ``"documents"`` collection
        under the default profile (``""`` or ``"ollama_nomic"``); every other
        combination becomes ``"<base>__<profile name>"``.
        """
        default_aliases = {"documents"}
        if self.embedding_profile_name in {"", "ollama_nomic"} and base_collection_name in default_aliases:
            return base_collection_name
        return f"{base_collection_name}__{self.embedding_profile_name}"

    def _assert_embedding_dimension(self, embedding: List[float]) -> None:
        """Raise ``ValueError`` if ``embedding`` length differs from the profile dimension."""
        expected = int(self.embedding_profile.dimension)
        if len(embedding) != expected:
            raise ValueError(
                f"Embedding dimension mismatch for profile {self.embedding_profile_name!r}: "
                f"expected {expected}, got {len(embedding)}"
            )

    def create_collection(self, collection_name: str) -> None:
        """Ensure the profile-specific collection exists in the active backend.

        In Qdrant mode the collection is created with cosine distance and the
        profile's vector size, and only when it does not already exist.
        """
        profile_collection_name = self.collection_name_for_profile(collection_name)
        if self.mode == "dev":
            self.chroma_client.get_or_create_collection(name=profile_collection_name)
            logger.info("ChromaDB collection %r ready", profile_collection_name)
        else:
            from qdrant_client.http.models import Distance, VectorParams  # noqa: PLC0415
            if not self.qdrant_client.collection_exists(profile_collection_name):
                self.qdrant_client.create_collection(
                    collection_name=profile_collection_name,
                    vectors_config=VectorParams(size=self.embedding_profile.dimension, distance=Distance.COSINE),
                )
                logger.info("Qdrant collection %r created", profile_collection_name)

    # --- insert ---

    def add_documents(self, collection_name: str, documents: List[Dict]) -> None:
        """Insert documents with embeddings generated from document['text'].

        Each document dict must have 'id' and 'text'; all other keys go into
        metadata. Documents are processed in slices of ``BATCH_SIZE`` and
        upserted, so re-adding an existing id overwrites it.

        Raises:
            KeyError: If a document lacks 'text' or 'id'.
            ValueError: If any generated embedding has the wrong dimension.
        """
        profile_collection_name = self.collection_name_for_profile(collection_name)
        for batch_start in range(0, len(documents), BATCH_SIZE):
            batch = documents[batch_start: batch_start + BATCH_SIZE]
            texts = [doc["text"] for doc in batch]
            embeddings = self.generate_embeddings_batch(texts)
            # Validate the whole batch before writing any of it.
            for emb in embeddings:
                self._assert_embedding_dimension(emb)

            if self.mode == "dev":
                collection = self.chroma_client.get_or_create_collection(name=profile_collection_name)
                # A document with no extra keys gets None instead of an empty dict.
                metadatas = [
                    ({k: v for k, v in doc.items() if k not in ("id", "text")} or None)
                    for doc in batch
                ]
                collection.upsert(  # type: ignore[arg-type]
                    ids=[str(doc["id"]) for doc in batch],
                    documents=texts,
                    embeddings=embeddings,  # type: ignore[arg-type]
                    metadatas=metadatas,  # type: ignore[arg-type]
                )
            else:
                from qdrant_client.http.models import PointStruct  # noqa: PLC0415
                # NOTE(review): unlike the Chroma branch, the document text is
                # not stored here and ids are passed through unconverted.
                # Qdrant presumably requires unsigned-int or UUID point ids;
                # confirm that callers supply compatible ids in this mode.
                points = [
                    PointStruct(
                        id=doc["id"],
                        vector=embedding,
                        payload={k: v for k, v in doc.items() if k not in ("id", "text")},
                    )
                    for doc, embedding in zip(batch, embeddings)
                ]
                self.qdrant_client.upsert(collection_name=profile_collection_name, points=points)

            logger.info("Upserted batch of %d documents into %r", len(batch), profile_collection_name)

    # --- query ---

    @staticmethod
    def _build_chroma_where(filters: Optional[Dict]) -> Optional[Dict]:
        """Translate equality filters into a Chroma ``where`` clause.

        Chroma accepts a single top-level key or operator, so several
        key/value pairs are combined under ``$and``. A single-entry dict is
        passed through unchanged, which keeps caller-built operator clauses
        working.
        """
        if not filters:
            return None
        if len(filters) == 1:
            return dict(filters)
        return {"$and": [{key: value} for key, value in filters.items()]}

    def query_documents(
        self,
        collection_name: str,
        query_text: str,
        top_k: int = 5,
        filters: Optional[Dict] = None,
    ) -> List[Dict]:
        """Search for similar documents, optionally filtered by metadata key/value pairs.

        All ``filters`` entries must match (logical AND) in both backends.

        Returns:
            In dev mode, dicts with keys ``id``, ``text``, ``metadata`` and
            ``distance``. In Qdrant mode, dicts with keys ``id``, ``metadata``
            (the stored payload) and ``score``; no ``text`` key is returned
            because the text is not written to the payload on insert.

        Raises:
            ValueError: If the query embedding has the wrong dimension.
        """
        profile_collection_name = self.collection_name_for_profile(collection_name)
        query_embedding = self.generate_embedding(query_text)
        self._assert_embedding_dimension(query_embedding)

        if self.mode == "dev":
            collection = self.chroma_client.get_or_create_collection(name=profile_collection_name)
            results = collection.query(
                query_embeddings=[query_embedding],  # type: ignore[arg-type]
                n_results=top_k,
                where=self._build_chroma_where(filters),
            )
            # Results are nested one list per query embedding; only one was sent.
            ids: List[str] = (results["ids"] or [[]])[0]
            docs: List[str] = (results["documents"] or [[]])[0]
            metas: List[Dict] = (results["metadatas"] or [[]])[0]  # type: ignore[assignment]
            dists: List[float] = (results["distances"] or [[]])[0]
            return [
                {"id": id_, "text": doc, "metadata": meta, "distance": dist}
                for id_, doc, meta, dist in zip(ids, docs, metas, dists)
            ]

        from qdrant_client.http.models import FieldCondition, Filter, MatchValue  # noqa: PLC0415
        search_filter: Optional[Any] = None
        if filters:
            search_filter = Filter(
                must=[
                    FieldCondition(key=k, match=MatchValue(value=v))
                    for k, v in filters.items()
                ]
            )

        response = self.qdrant_client.query_points(
            collection_name=profile_collection_name,
            query=query_embedding,
            limit=top_k,
            query_filter=search_filter,
        )
        return [
            {"id": hit.id, "metadata": hit.payload, "score": hit.score}
            for hit in response.points
        ]