Spaces:
Running
Running
File size: 9,902 Bytes
'''
Vector database layer with interchangeable backends.

- ChromaDB for development
- Qdrant for production scaling
- Embedding generation via Ollama
- Batch operations for efficiency
- Metadata filtering capabilities
'''
import logging
import os
import time
from typing import Any, Dict, List, Optional, Sequence
import chromadb
import ollama
from src.utils.config import EmbeddingProfile
# qdrant_client is imported lazily (inside the VectorDatabase.qdrant_client
# property and the Qdrant-only code paths) so the module can be imported on
# environments where qdrant-client is not installed (e.g. HF Spaces running in
# Chroma-only / dev mode).
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Number of documents embedded and upserted per round trip in
# VectorDatabase.add_documents.
BATCH_SIZE = 100
class VectorDatabase:
    """Vector store facade over ChromaDB or Qdrant with built-in embedding.

    ``mode == "dev"`` routes every operation to a persistent ChromaDB client;
    any other mode routes to Qdrant. Both clients are created lazily on first
    use. Embeddings come from Ollama, or from a locally loaded
    sentence-transformers model when the embedding profile's provider is
    ``"sentence_transformers"``.
    """

    # Retry policy for Ollama embedding calls: total number of attempts and
    # the linear backoff step (sleep = step * attempt number) between them.
    _EMBED_ATTEMPTS = 3
    _EMBED_BACKOFF_SECONDS = 0.35

    # Document keys that are handled explicitly and never copied to metadata.
    _RESERVED_KEYS = ("id", "text")
    # Qdrant payload key under which the raw document text is stored.
    _TEXT_PAYLOAD_KEY = "text"

    def __init__(
        self,
        mode: str = "dev",
        qdrant_host: str = "localhost",
        qdrant_port: int = 6333,
        chroma_path: str = "./chroma_db",
        embedding_profile_name: str = "ollama_nomic",
        embedding_profile: Optional["EmbeddingProfile"] = None,
    ):
        """Store connection settings; no database connection is opened here.

        Args:
            mode: ``"dev"`` selects ChromaDB, anything else selects Qdrant.
            qdrant_host: Host passed to ``QdrantClient``.
            qdrant_port: Port passed to ``QdrantClient``.
            chroma_path: Directory passed to ``chromadb.PersistentClient``.
            embedding_profile_name: Name used to derive per-profile
                collection names and in error messages.
            embedding_profile: Embedding settings. When omitted, a default
                profile (model ``nomic-embed-text``, dimension 768) is built
                with the provider read from ``DOC_EMBEDDING_PROVIDER``.
        """
        self.mode = mode
        self._chroma_path = chroma_path
        self._qdrant_host = qdrant_host
        self._qdrant_port = qdrant_port
        self._chroma_client: Optional[chromadb.ClientAPI] = None
        self._qdrant_client: Optional[Any] = None  # QdrantClient, imported lazily
        self._ollama_client = ollama.Client(host=self._resolve_ollama_host())
        # sentence-transformers models keyed by model name, loaded on demand.
        self._st_model_cache: Dict[str, Any] = {}
        self.embedding_profile_name = embedding_profile_name
        if not embedding_profile:
            default_provider = os.getenv("DOC_EMBEDDING_PROVIDER", "ollama").lower()
            embedding_profile = EmbeddingProfile(
                provider=default_provider,
                # NOTE(review): framework is read from the same environment
                # variable as provider -- confirm this is intentional.
                framework=default_provider,
                model="nomic-embed-text",
                dimension=768,
            )
        self.embedding_profile = embedding_profile

    # --- client accessors (lazy init) ---

    @property
    def chroma_client(self) -> "chromadb.ClientAPI":
        """Return the ChromaDB client, creating it on first access."""
        if self._chroma_client is None:
            self._chroma_client = chromadb.PersistentClient(path=self._chroma_path)
            logger.info("ChromaDB initialized at %s", self._chroma_path)
        return self._chroma_client

    @property
    def qdrant_client(self) -> Any:
        """Return the Qdrant client, importing and creating it on first access."""
        if self._qdrant_client is None:
            from qdrant_client import QdrantClient  # noqa: PLC0415

            self._qdrant_client = QdrantClient(host=self._qdrant_host, port=self._qdrant_port)
            logger.info("Qdrant initialized at %s:%s", self._qdrant_host, self._qdrant_port)
        return self._qdrant_client

    # --- embedding ---

    @staticmethod
    def _resolve_ollama_host() -> str:
        """Return the Ollama URL: OLLAMA_BASE_URL, then OLLAMA_HOST, then localhost."""
        return (
            os.getenv("OLLAMA_BASE_URL")
            or os.getenv("OLLAMA_HOST")
            or "http://localhost:11434"
        )

    def _generate_st_embedding(self, text: str, model_name: str) -> List[float]:
        """Embed ``text`` with a sentence-transformers model, caching the model."""
        if model_name not in self._st_model_cache:
            from sentence_transformers import SentenceTransformer  # noqa: PLC0415

            self._st_model_cache[model_name] = SentenceTransformer(model_name)
            logger.info("SentenceTransformer loaded: %s", model_name)
        return self._st_model_cache[model_name].encode(text, show_progress_bar=False).tolist()

    def generate_embedding(self, text: str) -> List[float]:
        """Return the embedding of ``text`` for the active embedding profile.

        Ollama calls are retried up to ``_EMBED_ATTEMPTS`` times with a linear
        backoff; the exception of the final attempt is re-raised unchanged.

        Raises:
            Exception: Whatever the last failed Ollama attempt raised.
            RuntimeError: Only if ``_EMBED_ATTEMPTS`` is configured below 1.
        """
        provider = self.embedding_profile.provider.strip().lower()
        if provider == "sentence_transformers":
            return self._generate_st_embedding(text, self.embedding_profile.model)

        attempts = self._EMBED_ATTEMPTS
        for attempt in range(1, attempts + 1):
            try:
                response = self._ollama_client.embeddings(
                    model=self.embedding_profile.model, prompt=text
                )
                return response["embedding"]  # type: ignore[return-value]
            except Exception as exc:
                if attempt == attempts:
                    raise
                logger.warning(
                    "Ollama embedding attempt %d/%d failed (%s); retrying",
                    attempt,
                    attempts,
                    exc,
                )
                time.sleep(self._EMBED_BACKOFF_SECONDS * attempt)
        # Reached only when the loop body never ran (attempts < 1).
        raise RuntimeError("Unexpected Ollama embedding retry state")

    def generate_embeddings_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed each text in order, one ``generate_embedding`` call per text."""
        return [self.generate_embedding(t) for t in texts]

    # --- collection management ---

    def collection_name_for_profile(self, base_collection_name: str) -> str:
        """Map a base collection name to the name used for the active profile.

        The default profile (``""`` or ``"ollama_nomic"``) keeps the bare name
        for the ``"documents"`` collection; every other combination is
        suffixed with ``__<profile name>``.
        """
        default_aliases = {"documents"}
        if self.embedding_profile_name in {"", "ollama_nomic"} and base_collection_name in default_aliases:
            return base_collection_name
        return f"{base_collection_name}__{self.embedding_profile_name}"

    def _assert_embedding_dimension(self, embedding: List[float]) -> None:
        """Raise ``ValueError`` if ``embedding`` does not match the profile dimension."""
        expected = int(self.embedding_profile.dimension)
        if len(embedding) != expected:
            raise ValueError(
                f"Embedding dimension mismatch for profile {self.embedding_profile_name!r}: "
                f"expected {expected}, got {len(embedding)}"
            )

    def create_collection(self, collection_name: str) -> None:
        """Ensure the profile-specific collection exists in the active backend."""
        profile_collection_name = self.collection_name_for_profile(collection_name)
        if self.mode == "dev":
            self.chroma_client.get_or_create_collection(name=profile_collection_name)
            logger.info("ChromaDB collection %r ready", profile_collection_name)
        else:
            from qdrant_client.http.models import Distance, VectorParams  # noqa: PLC0415

            if not self.qdrant_client.collection_exists(profile_collection_name):
                self.qdrant_client.create_collection(
                    collection_name=profile_collection_name,
                    vectors_config=VectorParams(size=self.embedding_profile.dimension, distance=Distance.COSINE),
                )
                logger.info("Qdrant collection %r created", profile_collection_name)

    # --- helpers ---

    @classmethod
    def _metadata_of(cls, doc: Dict) -> Dict:
        """Return every key of ``doc`` except the reserved ``id`` and ``text``."""
        return {k: v for k, v in doc.items() if k not in cls._RESERVED_KEYS}

    @staticmethod
    def _chroma_where(filters: Optional[Dict]) -> Optional[Dict]:
        """Translate equality filters into a ChromaDB ``where`` clause.

        Chroma accepts a single ``{key: value}`` condition directly, but
        several conditions must be combined under one operator, so they are
        wrapped in ``$and`` (matching the Qdrant branch's ``must`` semantics).
        Returns ``None`` when there is nothing to filter on.
        """
        if not filters:
            return None
        if len(filters) == 1:
            return dict(filters)
        return {"$and": [{key: value} for key, value in filters.items()]}

    # --- insert ---

    def add_documents(self, collection_name: str, documents: List[Dict]) -> None:
        """Insert documents with embeddings generated from document['text'].

        Each document dict must have 'id' and 'text'; all other keys go into
        metadata. Documents are processed in batches of ``BATCH_SIZE``. In
        Qdrant mode the text is stored in the point payload under ``"text"``
        so that queries can return it, as ChromaDB does natively.

        Raises:
            ValueError: If a generated embedding has the wrong dimension.
        """
        profile_collection_name = self.collection_name_for_profile(collection_name)
        collection = None  # ChromaDB collection, resolved once on the first batch
        for batch_start in range(0, len(documents), BATCH_SIZE):
            batch = documents[batch_start: batch_start + BATCH_SIZE]
            texts = [doc["text"] for doc in batch]
            embeddings = self.generate_embeddings_batch(texts)
            for emb in embeddings:
                self._assert_embedding_dimension(emb)

            if self.mode == "dev":
                if collection is None:
                    collection = self.chroma_client.get_or_create_collection(name=profile_collection_name)
                # An empty metadata dict is replaced by None for that document.
                metadatas = [(self._metadata_of(doc) or None) for doc in batch]
                collection.upsert(  # type: ignore[arg-type]
                    ids=[str(doc["id"]) for doc in batch],
                    documents=texts,
                    embeddings=embeddings,  # type: ignore[arg-type]
                    metadatas=metadatas,  # type: ignore[arg-type]
                )
            else:
                from qdrant_client.http.models import PointStruct  # noqa: PLC0415

                # NOTE(review): ids are passed to Qdrant unchanged, whereas the
                # ChromaDB branch stringifies them -- confirm callers supply
                # ids Qdrant accepts.
                points = [
                    PointStruct(
                        id=doc["id"],
                        vector=embedding,
                        payload={**self._metadata_of(doc), self._TEXT_PAYLOAD_KEY: doc["text"]},
                    )
                    for doc, embedding in zip(batch, embeddings)
                ]
                self.qdrant_client.upsert(collection_name=profile_collection_name, points=points)
            logger.info("Upserted batch of %d documents into %r", len(batch), profile_collection_name)

    # --- query ---

    def query_documents(
        self,
        collection_name: str,
        query_text: str,
        top_k: int = 5,
        filters: Optional[Dict] = None,
    ) -> List[Dict]:
        """Search for similar documents, optionally filtered by metadata key/value pairs.

        All filter pairs must match (logical AND) in both backends.

        Returns:
            One dict per hit with ``id``, ``text`` and ``metadata``, plus
            ``distance`` in dev (ChromaDB) mode or ``score`` in Qdrant mode.
            In Qdrant mode ``text`` is ``None`` for points whose payload has
            no stored text.

        Raises:
            ValueError: If the query embedding has the wrong dimension.
        """
        profile_collection_name = self.collection_name_for_profile(collection_name)
        query_embedding = self.generate_embedding(query_text)
        self._assert_embedding_dimension(query_embedding)

        if self.mode == "dev":
            collection = self.chroma_client.get_or_create_collection(name=profile_collection_name)
            results = collection.query(
                query_embeddings=[query_embedding],  # type: ignore[arg-type]
                n_results=top_k,
                where=self._chroma_where(filters),
            )
            # Results are nested per query embedding; only one was sent.
            ids: List[str] = (results["ids"] or [[]])[0]
            docs: List[str] = (results["documents"] or [[]])[0]
            metas: List[Dict] = (results["metadatas"] or [[]])[0]  # type: ignore[assignment]
            dists: List[float] = (results["distances"] or [[]])[0]
            return [
                {"id": id_, "text": doc, "metadata": meta, "distance": dist}
                for id_, doc, meta, dist in zip(ids, docs, metas, dists)
            ]

        from qdrant_client.http.models import FieldCondition, Filter, MatchValue  # noqa: PLC0415

        search_filter: Optional[Any] = None
        if filters:
            search_filter = Filter(
                must=[
                    FieldCondition(key=k, match=MatchValue(value=v))
                    for k, v in filters.items()
                ]
            )
        response = self.qdrant_client.query_points(
            collection_name=profile_collection_name,
            query=query_embedding,
            limit=top_k,
            query_filter=search_filter,
        )
        hits: List[Dict] = []
        for hit in response.points:
            # Copy the payload so the stored text can be split from metadata.
            payload = dict(hit.payload or {})
            text = payload.pop(self._TEXT_PAYLOAD_KEY, None)
            hits.append({"id": hit.id, "text": text, "metadata": payload, "score": hit.score})
        return hits
|