| """Embeddings, Qdrant client, collection setup, and vector store.""" |
|
|
| from __future__ import annotations |
|
|
| from collections.abc import Iterator |
| from functools import lru_cache |
|
|
| from langchain_qdrant import QdrantVectorStore |
| from qdrant_client import QdrantClient |
| from qdrant_client.http import models as qmodels |
|
|
| from src.config import settings |
| from src.embeddings import get_embeddings |
|
|
# Default page size (``limit``) for each ``scroll`` request issued by ``scroll_all``.
_SCROLL_PAGE_SIZE: int = 256

# Payload paths that ``ensure_collection`` keeps indexed, mapped to the Qdrant
# schema type each index is expected to have. A missing index is created; an
# existing index with a different type makes ``ensure_collection`` raise.
INDEXED_PAYLOAD_FIELDS: dict[str, qmodels.PayloadSchemaType] = {
    "metadata.document_id": qmodels.PayloadSchemaType.KEYWORD,
    "metadata.filename": qmodels.PayloadSchemaType.KEYWORD,
    "metadata.page": qmodels.PayloadSchemaType.INTEGER,
}
|
|
|
|
def close_client() -> None:
    """Close the cached Qdrant client, if one exists, and drop it from the cache.

    This is a no-op when ``get_client`` has not created a client yet (or the
    cache was already cleared), so it never opens a client just to close it.
    The cache is cleared even if ``close()`` raises. The next ``get_client()``
    call then builds a fresh client instead of returning a possibly half-closed one.
    """
    if get_client.cache_info().currsize == 0:
        return
    # The cache is populated and get_client takes no arguments, so this is a
    # cache hit that returns the existing instance.
    client = get_client()
    try:
        client.close()
    finally:
        get_client.cache_clear()
|
|
|
|
@lru_cache(maxsize=1)
def get_client() -> QdrantClient:
    """Build the shared local Qdrant client that persists to ``settings.storage_dir``.

    The storage directory (and any missing parents) is created first. Because
    of the ``lru_cache`` wrapper, later calls return the same instance until
    ``get_client.cache_clear()`` runs (see ``close_client``).
    """
    storage_dir = settings.storage_dir
    storage_dir.mkdir(exist_ok=True, parents=True)
    return QdrantClient(path=str(storage_dir))
|
|
|
|
def ensure_collection(recreate: bool = False, collection_name: str | None = None) -> None:
    """Make sure the collection exists and carries the expected payload indexes.

    Args:
        recreate: When true and the collection already exists, delete it and
            create it again from scratch.
        collection_name: Target collection. Falls back to
            ``settings.qdrant_collection`` when ``None`` or empty.

    A newly created collection uses cosine distance. Its vector size is
    measured by embedding the text ``"dimension probe"`` with the configured
    embeddings.

    Raises:
        ValueError: If an existing payload index in ``INDEXED_PAYLOAD_FIELDS``
            has a different schema type than expected.
    """
    client = get_client()
    target = collection_name or settings.qdrant_collection

    must_create = not client.collection_exists(target)
    if recreate and not must_create:
        client.delete_collection(target)
        must_create = True

    if must_create:
        # Ask the embedding model for one vector so the collection size always
        # matches whatever model is configured.
        probe_vector = get_embeddings().embed_query("dimension probe")
        client.create_collection(
            collection_name=target,
            vectors_config=qmodels.VectorParams(
                size=len(probe_vector),
                distance=qmodels.Distance.COSINE,
            ),
        )

    _sync_payload_indexes(client, target)


def _sync_payload_indexes(client: QdrantClient, collection_name: str) -> None:
    """Create missing indexes from ``INDEXED_PAYLOAD_FIELDS``; reject mismatched ones."""
    current = client.get_collection(collection_name).payload_schema or {}
    for field_name, wanted in INDEXED_PAYLOAD_FIELDS.items():
        info = current.get(field_name)
        if info is None:
            client.create_payload_index(
                collection_name=collection_name,
                field_name=field_name,
                field_schema=wanted,
            )
            continue
        actual = getattr(info, "data_type", None)
        if actual != wanted:
            raise ValueError(
                f"Payload index for '{field_name}' has schema "
                f"{actual!r}, expected {wanted!r}."
            )
|
|
|
|
def scroll_all(
    collection_name: str,
    scroll_filter: qmodels.Filter | None = None,
    with_payload: bool | list[str] = True,
    limit: int = _SCROLL_PAGE_SIZE,
) -> Iterator[list[qmodels.Record]]:
    """Yield pages of Qdrant points (no vectors) until the collection is exhausted.

    Pages are fetched with ``client.scroll`` and follow the offset cursor it
    returns. A collection that does not exist yields nothing.

    Args:
        collection_name: Collection to read.
        scroll_filter: Optional filter passed to every ``scroll`` call.
        with_payload: ``True`` for the whole payload, or a list of payload keys
            to return.
        limit: Maximum number of points requested per page.
    """
    client = get_client()
    # Check existence through the public API instead of string-matching a
    # ValueError raised by scroll(). The old catch wrapped every page, so it
    # could swallow unrelated errors mentioning "not found". It could also
    # silently truncate a scan if the collection disappeared mid-way.
    if not client.collection_exists(collection_name):
        return

    offset = None
    while True:
        points, offset = client.scroll(
            collection_name=collection_name,
            scroll_filter=scroll_filter,
            limit=limit,
            offset=offset,
            with_payload=with_payload,
            with_vectors=False,
        )
        yield points
        if offset is None:  # no further pages
            return
|
|
|
|
def get_vector_store(collection_name: str | None = None) -> QdrantVectorStore:
    """Wrap the shared Qdrant client in a LangChain ``QdrantVectorStore``.

    Args:
        collection_name: Collection to bind. Falls back to
            ``settings.qdrant_collection`` when ``None`` or empty.

    The store uses the client from ``get_client()`` and the embeddings from
    ``get_embeddings()``.
    """
    client = get_client()
    target = collection_name if collection_name else settings.qdrant_collection
    embedding = get_embeddings()
    return QdrantVectorStore(
        client=client,
        collection_name=target,
        embedding=embedding,
    )
|
|
|
|
def list_documents(collection_name: str | None = None) -> list[dict[str, object]]:
    """List indexed documents with filename, document_id, pages, and chunk counts.

    Returns one entry per filename matching the API `DocumentInfo` shape,
    sorted by filename. Each entry has ``filename``, ``document_id``,
    ``pages`` (sorted distinct page numbers), ``page_count`` and
    ``chunk_count``.

    A chunk is counted only if its ``metadata`` payload is a dict with a
    non-empty ``filename``, a non-empty ``document_id`` and an integer
    ``page``. If several document_ids share a filename, their chunks are
    merged under that filename and the first document_id seen is reported.

    Args:
        collection_name: Collection to scan. Falls back to
            ``settings.qdrant_collection`` when ``None`` or empty, the same
            convention used by ``ensure_collection`` and ``get_vector_store``.
    """
    target = collection_name or settings.qdrant_collection
    doc_ids: dict[str, str] = {}
    pages: dict[str, set[int]] = {}
    chunks: dict[str, int] = {}

    for batch in scroll_all(target, with_payload=["metadata"]):
        for point in batch:
            meta = (point.payload or {}).get("metadata")
            # Skip missing or malformed metadata instead of crashing on ``.get``.
            if not isinstance(meta, dict):
                continue
            filename = meta.get("filename")
            document_id = meta.get("document_id")
            page = meta.get("page")
            # ``bool`` is a subclass of ``int``; don't accept True/False as pages 1/0.
            if (
                not filename
                or not document_id
                or not isinstance(page, int)
                or isinstance(page, bool)
            ):
                continue
            key = str(filename)
            doc_ids.setdefault(key, str(document_id))
            pages.setdefault(key, set()).add(page)
            chunks[key] = chunks.get(key, 0) + 1

    return [
        {
            "filename": fn,
            "document_id": doc_ids[fn],
            "pages": sorted(pages[fn]),
            "page_count": len(pages[fn]),
            "chunk_count": chunks[fn],
        }
        for fn in sorted(doc_ids)
    ]
|
|
|
|