| """ |
| qdrant_store.py — Qdrant vector store abstraction. |
| |
Supports two modes (controlled by config):
- "local": embedded Qdrant with on-disk persistence (no server needed).
- "remote": connects to a Qdrant server via host:port, or via a full URL
  (e.g. Qdrant Cloud, typically together with an API key).
| |
| The store handles collection lifecycle (create / recreate), upserting points |
| with payloads, and searching. It deliberately does NOT embed texts — that |
| responsibility stays with the embedding module, keeping concerns separated. |
| """ |
|
|
from __future__ import annotations

import hashlib
import logging
from typing import Any, Sequence

import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    PointStruct,
    VectorParams,
)
|
|
logger = logging.getLogger(__name__)


# Config-facing metric names mapped to qdrant-client distance enum members.
# QdrantStore looks these up when creating a collection.
_DISTANCE_MAP = dict(
    cosine=Distance.COSINE,
    euclid=Distance.EUCLID,
    dot=Distance.DOT,
)
|
|
|
|
| class QdrantStore: |
| """ |
| Thin abstraction over qdrant-client. |
| |
| Parameters |
| ---------- |
| mode : str |
| "local" (embedded, file-backed) or "remote" (server). |
| local_path : str | None |
| Disk path for local mode. |
| host : str |
| Qdrant server host (remote mode). |
| port : int |
| Qdrant server port (remote mode). |
| api_key : str | None |
| API key for Qdrant Cloud (remote mode). |
| collection_name : str |
| Name of the vector collection. |
| distance : str |
| Distance metric ("cosine", "euclid", "dot"). |
| vector_dim : int |
| Dimensionality of vectors to store. |
| """ |
|
|
| def __init__( |
| self, |
| mode: str = "local", |
| local_path: str | None = None, |
| url: str | None = None, |
| host: str = "localhost", |
| port: int = 6333, |
| api_key: str | None = None, |
| collection_name: str = "documents", |
| distance: str = "cosine", |
| vector_dim: int = 384, |
| ) -> None: |
| self.collection_name = collection_name |
| self.vector_dim = vector_dim |
| self.distance = _DISTANCE_MAP.get(distance, Distance.COSINE) |
|
|
| if mode == "local": |
| logger.info("Connecting to Qdrant in local/embedded mode: %s", local_path) |
| self._client = QdrantClient(path=local_path) |
| elif url: |
| logger.info("Connecting to Qdrant Cloud at %s", url) |
| self._client = QdrantClient(url=url, api_key=api_key) |
| else: |
| logger.info("Connecting to Qdrant server at %s:%d", host, port) |
| self._client = QdrantClient(host=host, port=port, api_key=api_key) |
|
|
| |
| |
| |
|
|
| def ensure_collection(self, recreate: bool = False) -> None: |
| """ |
| Create the collection if it doesn't exist. |
| If *recreate* is True, drop and recreate it (useful during re-indexing). |
| """ |
| exists = self._client.collection_exists(self.collection_name) |
|
|
| if exists and recreate: |
| logger.warning("Recreating collection '%s'", self.collection_name) |
| self._client.delete_collection(self.collection_name) |
| exists = False |
|
|
| if not exists: |
| self._client.create_collection( |
| collection_name=self.collection_name, |
| vectors_config=VectorParams( |
| size=self.vector_dim, |
| distance=self.distance, |
| ), |
| ) |
| logger.info( |
| "Created collection '%s' (dim=%d, distance=%s)", |
| self.collection_name, self.vector_dim, self.distance, |
| ) |
| else: |
| logger.info("Collection '%s' already exists", self.collection_name) |
|
|
| |
| |
| |
|
|
| def upsert_batch( |
| self, |
| ids: Sequence[str], |
| vectors: np.ndarray, |
| payloads: Sequence[dict[str, Any]], |
| ) -> None: |
| """ |
| Insert or update a batch of points. |
| |
| Parameters |
| ---------- |
| ids : list of str |
| Unique point identifiers (we use chunk_id). |
| vectors : np.ndarray, shape (N, D) |
| Embedding vectors. |
| payloads : list of dict |
| Metadata payloads (one per point). |
| """ |
| points = [ |
| PointStruct( |
| id=idx, |
| vector=vec.tolist(), |
| payload=pay, |
| ) |
| for idx, (vec, pay) in enumerate(zip(vectors, payloads)) |
| ] |
|
|
| |
| |
| |
| import hashlib |
| points = [] |
| for cid, vec, pay in zip(ids, vectors, payloads): |
| |
| h = int(hashlib.sha256(cid.encode()).hexdigest()[:16], 16) |
| points.append( |
| PointStruct(id=h, vector=vec.tolist(), payload=pay) |
| ) |
|
|
| self._client.upsert( |
| collection_name=self.collection_name, |
| points=points, |
| ) |
| logger.debug("Upserted %d points into '%s'", len(points), self.collection_name) |
|
|
| |
| |
| |
|
|
| def search( |
| self, |
| query_vector: np.ndarray, |
| top_k: int = 5, |
| score_threshold: float | None = None, |
| query_filter: Any = None, |
| ) -> list[dict[str, Any]]: |
| """ |
| Search the collection for nearest neighbours. |
| |
| Returns a list of dicts with keys: id, score, payload. |
| The *query_filter* param is a placeholder for Qdrant filter objects |
| (to be used when we add metadata filtering later). |
| """ |
| results = self._client.query_points( |
| collection_name=self.collection_name, |
| query=query_vector.tolist(), |
| limit=top_k, |
| score_threshold=score_threshold, |
| query_filter=query_filter, |
| ).points |
|
|
| hits: list[dict[str, Any]] = [] |
| for r in results: |
| hits.append({ |
| "id": r.id, |
| "score": r.score, |
| "payload": r.payload, |
| }) |
| return hits |
|
|
| |
| |
| |
|
|
| def collection_info(self) -> dict[str, Any]: |
| info = self._client.get_collection(self.collection_name) |
| result: dict[str, Any] = { |
| "name": self.collection_name, |
| "points_count": info.points_count, |
| "status": str(info.status), |
| } |
| |
| if hasattr(info, "vectors_count"): |
| result["vectors_count"] = info.vectors_count |
| return result |
|
|
| def close(self) -> None: |
| """Close the client connection (relevant for local mode flush).""" |
| self._client.close() |
|
|
|
|
| |
| |
| |
|
|
def build_qdrant_store(cfg: dict, vector_dim: int | None = None) -> QdrantStore:
    """
    Build a QdrantStore from the config dict.

    Reads the ``qdrant`` section of *cfg* (mode, local_path, url, host, port,
    api_key, collection_name, distance). Missing keys fall back to the same
    defaults as ``QdrantStore``. If *vector_dim* is not given (or is falsy),
    uses the value from cfg["embeddings"]["dimension"].

    Raises
    ------
    KeyError
        If *cfg* has no "qdrant" section, or *vector_dim* is not given and
        cfg["embeddings"]["dimension"] is missing.
    """
    q = cfg["qdrant"]
    dim = vector_dim or cfg["embeddings"]["dimension"]
    return QdrantStore(
        mode=q.get("mode", "local"),
        local_path=q.get("local_path"),
        # Previously never forwarded, so URL / Qdrant Cloud connections could
        # not be configured.
        url=q.get("url"),
        host=q.get("host", "localhost"),
        port=q.get("port", 6333),
        api_key=q.get("api_key"),
        collection_name=q.get("collection_name", "documents"),
        distance=q.get("distance", "cosine"),
        vector_dim=dim,
    )
|
|