"""Qdrant vector store for dense retrieval.""" import hashlib import json import logging from qdrant_client import QdrantClient from qdrant_client.models import Distance, FieldCondition, Filter, MatchValue, PointStruct, VectorParams from src.models import ChunkStrategy, DocumentChunk, QueryResult logger = logging.getLogger(__name__) def _payload_to_chunk(payload: dict) -> DocumentChunk: """Convert a Qdrant payload dict to a DocumentChunk. Args: payload: Qdrant point payload. Returns: DocumentChunk reconstructed from the payload. """ return DocumentChunk( chunk_id=payload["chunk_id"], document_id=payload["document_id"], text=payload["text"], metadata=json.loads(payload["metadata"]), strategy=ChunkStrategy(payload["strategy"]), ) class VectorStore: """Manages document storage and dense retrieval via Qdrant.""" def __init__(self, path: str, collection_name: str, dimension: int, url: str = "") -> None: """Initialize the Qdrant vector store. Args: path: File system path for Qdrant local storage (used when *url* is empty). collection_name: Name of the Qdrant collection. dimension: Embedding vector dimension. url: Optional Qdrant server URL. When provided, connects over HTTP instead of using local file storage. """ self._collection_name = collection_name if url: self._client = QdrantClient(url=url) else: self._client = QdrantClient(path=path) existing = [c.name for c in self._client.get_collections().collections] if collection_name not in existing: self._client.create_collection( collection_name=collection_name, vectors_config=VectorParams(size=dimension, distance=Distance.COSINE), ) logger.info("Created Qdrant collection '%s' (dim=%d)", collection_name, dimension) else: logger.info("Using existing Qdrant collection '%s'", collection_name) def add_chunks(self, chunks: list[DocumentChunk], embeddings: list[list[float]]) -> None: """Index document chunks with their embeddings. Args: chunks: List of document chunks to store. embeddings: Corresponding embedding vectors. Raises: ValueError: If chunks and embeddings have different lengths. """ if len(chunks) != len(embeddings): raise ValueError( f"chunks and embeddings length mismatch: {len(chunks)} vs {len(embeddings)}" ) if not chunks: return points = [ PointStruct( id=int(hashlib.sha256(chunk.chunk_id.encode()).hexdigest()[:15], 16), vector=embedding, payload={ "chunk_id": chunk.chunk_id, "document_id": chunk.document_id, "text": chunk.text, "metadata": json.dumps(chunk.metadata), "strategy": chunk.strategy.value, }, ) for chunk, embedding in zip(chunks, embeddings) ] self._client.upsert(collection_name=self._collection_name, points=points) logger.info("Indexed %d chunks into '%s'", len(points), self._collection_name) def search(self, query_embedding: list[float], top_k: int) -> list[QueryResult]: """Search for the most similar chunks by vector similarity. Args: query_embedding: The query embedding vector. top_k: Number of top results to return. Returns: List of QueryResult objects sorted by relevance. """ hits = self._client.query_points( collection_name=self._collection_name, query=query_embedding, limit=top_k, ).points results: list[QueryResult] = [ QueryResult(chunk=_payload_to_chunk(hit.payload), score=hit.score, source="dense") for hit in hits ] logger.debug("Dense search returned %d results", len(results)) return results def get_all_chunks(self) -> list[DocumentChunk]: """Retrieve all document chunks stored in the collection. Returns: List of all DocumentChunk objects in the collection. """ collection_info = self._client.get_collection(self._collection_name) total = collection_info.points_count if not total: return [] records, _offset = self._client.scroll( collection_name=self._collection_name, limit=total, with_payload=True, with_vectors=False, ) chunks = [_payload_to_chunk(record.payload) for record in records] logger.info("Loaded %d chunks from collection '%s'", len(chunks), self._collection_name) return chunks def list_document_ids(self) -> list[str]: """Return a sorted list of unique document IDs in the collection. Returns: Sorted list of document ID strings. """ all_chunks = self.get_all_chunks() ids = sorted({chunk.document_id for chunk in all_chunks}) logger.debug("Found %d unique document IDs", len(ids)) return ids def get_chunks_by_document_id(self, document_id: str) -> list[DocumentChunk]: """Retrieve all chunks belonging to a specific document. Uses a Qdrant payload filter to avoid loading the full collection. Args: document_id: The document identifier to filter by. Returns: List of DocumentChunk objects for that document, in storage order. """ records, _offset = self._client.scroll( collection_name=self._collection_name, scroll_filter=Filter( must=[FieldCondition(key="document_id", match=MatchValue(value=document_id))] ), limit=10_000, with_payload=True, with_vectors=False, ) chunks = [_payload_to_chunk(record.payload) for record in records] logger.debug( "Fetched %d chunks for document '%s'", len(chunks), document_id ) return chunks def delete_collection(self) -> None: """Delete the entire collection from the store.""" self._client.delete_collection(collection_name=self._collection_name) logger.info("Deleted Qdrant collection '%s'", self._collection_name)