Spaces:
Running
Running
'''
Vector database wrapper.

- ChromaDB for development
- Qdrant for production scaling
- Embedding generation via Ollama
- Batch operations for efficiency
- Metadata filtering capabilities
'''
| import logging | |
| import os | |
| import time | |
| from typing import Any, Dict, List, Optional, Sequence | |
| import chromadb | |
| import ollama | |
| from src.utils.config import EmbeddingProfile | |
# qdrant_client is imported lazily inside the methods that need it (the
# `qdrant_client` accessor and the Qdrant code paths) so the module can be
# imported on environments where qdrant-client is not installed (e.g. HF Spaces
# running in Chroma-only / dev mode).
| logger = logging.getLogger(__name__) | |
| BATCH_SIZE = 100 | |
class VectorDatabase:
    """Thin wrapper over a ChromaDB (``mode == "dev"``) or Qdrant vector store.

    Any ``mode`` other than ``"dev"`` selects the Qdrant code paths. Both
    backend clients are created lazily on first access, so constructing the
    wrapper does not touch ChromaDB or Qdrant. Embeddings come from
    sentence-transformers when the profile's provider is
    ``"sentence_transformers"`` and from Ollama otherwise.
    """

    def __init__(
        self,
        mode: str = "dev",
        qdrant_host: str = "localhost",
        qdrant_port: int = 6333,
        chroma_path: str = "./chroma_db",
        embedding_profile_name: str = "ollama_nomic",
        embedding_profile: Optional["EmbeddingProfile"] = None,
    ):
        """Store connection settings and build the embedding configuration.

        Args:
            mode: ``"dev"`` for ChromaDB; anything else uses Qdrant.
            qdrant_host: Host passed to ``QdrantClient`` on first use.
            qdrant_port: Port passed to ``QdrantClient`` on first use.
            chroma_path: Directory for the persistent ChromaDB client.
            embedding_profile_name: Profile label; used as a collection-name
                suffix (see ``collection_name_for_profile``).
            embedding_profile: Embedding settings. When omitted, a profile for
                ``nomic-embed-text`` (768 dimensions) is built, with provider
                and framework both read from ``DOC_EMBEDDING_PROVIDER``
                (default ``"ollama"``).
        """
        self.mode = mode
        self._chroma_path = chroma_path
        self._qdrant_host = qdrant_host
        self._qdrant_port = qdrant_port
        self._chroma_client: Optional["chromadb.ClientAPI"] = None
        self._qdrant_client: Optional[Any] = None  # QdrantClient, imported lazily
        self._ollama_client = ollama.Client(host=self._resolve_ollama_host())
        # Loaded SentenceTransformer models, keyed by model name.
        self._st_model_cache: Dict[str, Any] = {}
        self.embedding_profile_name = embedding_profile_name
        self.embedding_profile = embedding_profile or EmbeddingProfile(
            provider=os.getenv("DOC_EMBEDDING_PROVIDER", "ollama").lower(),
            framework=os.getenv("DOC_EMBEDDING_PROVIDER", "ollama").lower(),
            model="nomic-embed-text",
            dimension=768,
        )

    # --- client accessors (lazy init) ---

    @property
    def chroma_client(self) -> "chromadb.ClientAPI":
        """Return the persistent ChromaDB client, creating it on first access.

        Exposed as a property because every call site in this class reads it
        as an attribute (``self.chroma_client.get_or_create_collection``).
        """
        if self._chroma_client is None:
            self._chroma_client = chromadb.PersistentClient(path=self._chroma_path)
            logger.info("ChromaDB initialized at %s", self._chroma_path)
        return self._chroma_client

    @property
    def qdrant_client(self) -> Any:
        """Return the Qdrant client, importing and creating it on first access.

        Exposed as a property because every call site in this class reads it
        as an attribute (``self.qdrant_client.upsert``).
        """
        if self._qdrant_client is None:
            from qdrant_client import QdrantClient  # noqa: PLC0415

            self._qdrant_client = QdrantClient(host=self._qdrant_host, port=self._qdrant_port)
            logger.info("Qdrant initialized at %s:%s", self._qdrant_host, self._qdrant_port)
        return self._qdrant_client

    # --- embedding ---

    @staticmethod
    def _resolve_ollama_host() -> str:
        """Return the Ollama base URL.

        Precedence: ``OLLAMA_BASE_URL``, then ``OLLAMA_HOST``, then
        ``http://localhost:11434``. Empty environment values are skipped.
        """
        return (
            os.getenv("OLLAMA_BASE_URL")
            or os.getenv("OLLAMA_HOST")
            or "http://localhost:11434"
        )

    def _generate_st_embedding(self, text: str, model_name: str) -> List[float]:
        """Embed ``text`` with a SentenceTransformer model, loading it once per name."""
        if model_name not in self._st_model_cache:
            from sentence_transformers import SentenceTransformer  # noqa: PLC0415

            self._st_model_cache[model_name] = SentenceTransformer(model_name)
            logger.info("SentenceTransformer loaded: %s", model_name)
        return self._st_model_cache[model_name].encode(text, show_progress_bar=False).tolist()

    def generate_embedding(self, text: str) -> List[float]:
        """Return the embedding vector for ``text`` using the active profile.

        The Ollama path makes up to three attempts, sleeping 0.35 s and then
        0.70 s between them.

        Raises:
            Exception: Whatever the final Ollama attempt raised, unchanged.
        """
        provider = self.embedding_profile.provider.strip().lower()
        if provider == "sentence_transformers":
            return self._generate_st_embedding(text, self.embedding_profile.model)

        attempts = 3
        for attempt in range(1, attempts + 1):
            try:
                response = self._ollama_client.embeddings(
                    model=self.embedding_profile.model, prompt=text
                )
                return response["embedding"]  # type: ignore[return-value]
            except Exception:
                # Deliberately broad: any client/transport failure is retried,
                # and the last one propagates to the caller untouched.
                if attempt == attempts:
                    raise
                time.sleep(0.35 * attempt)
        # Not reachable: each iteration either returns or raises on the last
        # attempt. Kept so the function has no implicit ``None`` return.
        raise RuntimeError("Unexpected Ollama embedding retry state")

    def generate_embeddings_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed each text in order, one ``generate_embedding`` call per text."""
        return [self.generate_embedding(t) for t in texts]

    # --- collection management ---

    def collection_name_for_profile(self, base_collection_name: str) -> str:
        """Map a base collection name to the profile-specific collection name.

        The ``"documents"`` collection keeps its bare name under the default
        profile (``""`` or ``"ollama_nomic"``); every other combination gets
        the suffix ``__<embedding_profile_name>``.
        """
        default_aliases = {"documents"}
        if self.embedding_profile_name in {"", "ollama_nomic"} and base_collection_name in default_aliases:
            return base_collection_name
        return f"{base_collection_name}__{self.embedding_profile_name}"

    def _assert_embedding_dimension(self, embedding: List[float]) -> None:
        """Raise ``ValueError`` if ``embedding`` is not the profile's dimension."""
        expected = int(self.embedding_profile.dimension)
        if len(embedding) != expected:
            raise ValueError(
                f"Embedding dimension mismatch for profile {self.embedding_profile_name!r}: "
                f"expected {expected}, got {len(embedding)}"
            )

    def create_collection(self, collection_name: str) -> None:
        """Ensure the profile-specific collection exists in the active backend.

        Qdrant collections are created with cosine distance and the profile's
        vector dimension; an existing collection is left untouched.
        """
        profile_collection_name = self.collection_name_for_profile(collection_name)
        if self.mode == "dev":
            self.chroma_client.get_or_create_collection(name=profile_collection_name)
            logger.info("ChromaDB collection %r ready", profile_collection_name)
        else:
            from qdrant_client.http.models import Distance, VectorParams  # noqa: PLC0415

            if not self.qdrant_client.collection_exists(profile_collection_name):
                self.qdrant_client.create_collection(
                    collection_name=profile_collection_name,
                    vectors_config=VectorParams(size=self.embedding_profile.dimension, distance=Distance.COSINE),
                )
                logger.info("Qdrant collection %r created", profile_collection_name)

    # --- insert ---

    def add_documents(self, collection_name: str, documents: List[Dict]) -> None:
        """Insert documents with embeddings generated from document['text'].

        Each document dict must have 'id' and 'text'; all other keys go into
        metadata. Documents are embedded and upserted in batches of
        ``BATCH_SIZE``.

        NOTE(review): in Qdrant mode the payload excludes 'text', so the
        document text itself is not stored there — confirm this is intended.

        Raises:
            KeyError: If a document lacks 'text' or 'id'.
            ValueError: If a generated embedding has the wrong dimension.
        """
        profile_collection_name = self.collection_name_for_profile(collection_name)
        for batch_start in range(0, len(documents), BATCH_SIZE):
            batch = documents[batch_start: batch_start + BATCH_SIZE]
            texts = [doc["text"] for doc in batch]
            embeddings = self.generate_embeddings_batch(texts)
            for emb in embeddings:
                self._assert_embedding_dimension(emb)

            if self.mode == "dev":
                collection = self.chroma_client.get_or_create_collection(name=profile_collection_name)
                # A document with no extra keys gets ``None`` metadata rather
                # than an empty dict.
                metadatas = [
                    ({k: v for k, v in doc.items() if k not in ("id", "text")} or None)
                    for doc in batch
                ]
                collection.upsert(  # type: ignore[arg-type]
                    ids=[str(doc["id"]) for doc in batch],
                    documents=texts,
                    embeddings=embeddings,  # type: ignore[arg-type]
                    metadatas=metadatas,  # type: ignore[arg-type]
                )
            else:
                from qdrant_client.http.models import PointStruct  # noqa: PLC0415

                points = [
                    PointStruct(
                        id=doc["id"],
                        vector=embedding,
                        payload={k: v for k, v in doc.items() if k not in ("id", "text")},
                    )
                    for doc, embedding in zip(batch, embeddings)
                ]
                self.qdrant_client.upsert(collection_name=profile_collection_name, points=points)
            logger.info("Upserted batch of %d documents into %r", len(batch), profile_collection_name)

    # --- query ---

    @staticmethod
    def _build_chroma_where(filters: Optional[Dict]) -> Optional[Dict]:
        """Translate flat key/value filters into a Chroma ``where`` clause.

        A single entry is passed through unchanged; several entries are
        combined under ``$and``, because a multi-key dict is passed as-is
        otherwise and Chroma expects one top-level key or operator.
        """
        if not filters:
            return None
        if len(filters) == 1:
            return dict(filters)
        return {"$and": [{key: value} for key, value in filters.items()]}

    def query_documents(
        self,
        collection_name: str,
        query_text: str,
        top_k: int = 5,
        filters: Optional[Dict] = None,
    ) -> List[Dict]:
        """Search for similar documents, optionally filtered by metadata key/value pairs.

        All entries in ``filters`` must match (logical AND) in both backends.

        Returns:
            In dev (Chroma) mode, dicts with keys ``id``, ``text``,
            ``metadata`` and ``distance``. In Qdrant mode, dicts with keys
            ``id``, ``metadata`` and ``score`` (no ``text``).

        Raises:
            ValueError: If the query embedding has the wrong dimension.
        """
        profile_collection_name = self.collection_name_for_profile(collection_name)
        query_embedding = self.generate_embedding(query_text)
        self._assert_embedding_dimension(query_embedding)

        if self.mode == "dev":
            collection = self.chroma_client.get_or_create_collection(name=profile_collection_name)
            results = collection.query(
                query_embeddings=[query_embedding],  # type: ignore[arg-type]
                n_results=top_k,
                where=self._build_chroma_where(filters),
            )
            # Results are nested one list per query embedding; only one was sent.
            ids: List[str] = (results["ids"] or [[]])[0]
            docs: List[str] = (results["documents"] or [[]])[0]
            metas: List[Dict] = (results["metadatas"] or [[]])[0]  # type: ignore[assignment]
            dists: List[float] = (results["distances"] or [[]])[0]
            return [
                {"id": id_, "text": doc, "metadata": meta, "distance": dist}
                for id_, doc, meta, dist in zip(ids, docs, metas, dists)
            ]

        from qdrant_client.http.models import FieldCondition, Filter, MatchValue  # noqa: PLC0415

        search_filter: Optional[Any] = None
        if filters:
            search_filter = Filter(
                must=[
                    FieldCondition(key=k, match=MatchValue(value=v))
                    for k, v in filters.items()
                ]
            )
        response = self.qdrant_client.query_points(
            collection_name=profile_collection_name,
            query=query_embedding,
            limit=top_k,
            query_filter=search_filter,
        )
        return [
            {"id": hit.id, "metadata": hit.payload, "score": hit.score}
            for hit in response.points
        ]