File size: 4,923 Bytes
16fa4e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""Embeddings, Qdrant client, collection setup, and vector store."""

from __future__ import annotations

from collections.abc import Iterator
from functools import lru_cache

from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http import models as qmodels

from src.config import settings
from src.embeddings import get_embeddings

# Default number of points requested per page; used as the default `limit`
# of `scroll_all`.
_SCROLL_PAGE_SIZE = 256

# Payload fields that `ensure_collection` creates indexes for, mapped to the
# Qdrant payload schema type each index is expected to have.
INDEXED_PAYLOAD_FIELDS = {
    "metadata.document_id": qmodels.PayloadSchemaType.KEYWORD,
    "metadata.filename": qmodels.PayloadSchemaType.KEYWORD,
    "metadata.page": qmodels.PayloadSchemaType.INTEGER,
}


def close_client() -> None:
    """Close the cached Qdrant client, if any, and drop it from the cache.

    Does nothing when `get_client` has no cached value, so calling this never
    creates a client just to close it.

    Raises:
        Whatever `QdrantClient.close()` raises; the cache is cleared first
        regardless, so the next `get_client()` call builds a fresh client.
    """
    if get_client.cache_info().currsize == 0:
        return
    try:
        # The cache is populated, so this returns the existing instance.
        get_client().close()
    finally:
        # Always forget the instance: if close() failed we must not keep
        # handing out that client to later callers.
        get_client.cache_clear()


@lru_cache(maxsize=1)
def get_client() -> QdrantClient:
    """Build (once) and return the Qdrant client for `settings.storage_dir`.

    The storage directory is created if missing. The result is memoized via
    `lru_cache`, so later calls return the same client instance.
    """
    storage_path = settings.storage_dir
    storage_path.mkdir(parents=True, exist_ok=True)
    return QdrantClient(path=str(storage_path))


def ensure_collection(recreate: bool = False, collection_name: str | None = None) -> None:
    """Make sure the collection and its payload indexes are in place.

    Args:
        recreate: When true and the collection already exists, delete it and
            create it again from scratch.
        collection_name: Collection to operate on; falls back to
            `settings.qdrant_collection` when omitted or empty.

    Raises:
        ValueError: An existing payload index for one of
            `INDEXED_PAYLOAD_FIELDS` reports a different data type than the
            one expected.
    """
    qdrant = get_client()
    target = collection_name or settings.qdrant_collection

    present = qdrant.collection_exists(target)
    if recreate and present:
        qdrant.delete_collection(target)
        present = False

    if not present:
        # Embed a throwaway query to learn the embedding model's vector size.
        probe_vector = get_embeddings().embed_query("dimension probe")
        vector_params = qmodels.VectorParams(
            size=len(probe_vector),
            distance=qmodels.Distance.COSINE,
        )
        qdrant.create_collection(collection_name=target, vectors_config=vector_params)

    indexed_now = qdrant.get_collection(target).payload_schema or {}

    for field_name, field_schema in INDEXED_PAYLOAD_FIELDS.items():
        index_info = indexed_now.get(field_name)
        if index_info is not None:
            # Index already exists: only verify that its type is the expected one.
            existing_schema = getattr(index_info, "data_type", None)
            if existing_schema != field_schema:
                raise ValueError(
                    f"Payload index for '{field_name}' has schema "
                    f"{existing_schema!r}, expected {field_schema!r}."
                )
        else:
            qdrant.create_payload_index(
                collection_name=target,
                field_name=field_name,
                field_schema=field_schema,
            )


def scroll_all(
    collection_name: str,
    scroll_filter: qmodels.Filter | None = None,
    with_payload: bool | list[str] = True,
    limit: int = _SCROLL_PAGE_SIZE,
) -> Iterator[list]:
    """Iterate over a collection page by page, yielding each page of points.

    Vectors are never requested. Iteration ends when the client reports no
    further offset, or immediately (yielding nothing more) when the scroll
    call fails with a ValueError whose message contains "not found".

    Args:
        collection_name: Collection to scroll through.
        scroll_filter: Optional filter passed through to the client.
        with_payload: Payload selector passed through to the client.
        limit: Maximum number of points requested per page.

    Raises:
        ValueError: Any ValueError from the client other than "not found".
    """
    qdrant = get_client()
    cursor = None
    exhausted = False
    while not exhausted:
        try:
            page, cursor_after = qdrant.scroll(
                collection_name=collection_name,
                scroll_filter=scroll_filter,
                limit=limit,
                offset=cursor,
                with_payload=with_payload,
                with_vectors=False,
            )
        except ValueError as err:
            # Local Qdrant signals a missing collection with a ValueError;
            # treat that as "nothing to iterate" and re-raise anything else.
            if "not found" not in str(err).lower():
                raise
            return
        yield page
        exhausted = cursor_after is None
        cursor = cursor_after


def get_vector_store(collection_name: str | None = None) -> QdrantVectorStore:
    """Return a `QdrantVectorStore` wired to the shared client and embeddings.

    Args:
        collection_name: Collection to bind to; falls back to
            `settings.qdrant_collection` when omitted or empty.
    """
    qdrant = get_client()
    target = collection_name or settings.qdrant_collection
    embedder = get_embeddings()
    return QdrantVectorStore(client=qdrant, collection_name=target, embedding=embedder)


def list_documents() -> list[dict[str, object]]:
    """Summarize the indexed documents found in the default collection.

    Scrolls every point in `settings.qdrant_collection` and groups them by
    `metadata.filename`. Points lacking a filename, a document_id, or an
    integer page are skipped. For each filename the first document_id seen
    is kept.

    Returns:
        One dict per filename, sorted by filename, with the keys `filename`,
        `document_id`, `pages` (sorted distinct page numbers), `page_count`
        and `chunk_count` (the shape is intended to match the API's
        `DocumentInfo`).
    """
    first_doc_id: dict[str, str] = {}
    pages_seen: dict[str, set[int]] = {}
    chunk_totals: dict[str, int] = {}

    for points_page in scroll_all(settings.qdrant_collection, with_payload=["metadata"]):
        for pt in points_page:
            payload = pt.payload or {}
            metadata = payload.get("metadata") or {}
            raw_name = metadata.get("filename")
            raw_doc_id = metadata.get("document_id")
            page_no = metadata.get("page")
            if not (raw_name and raw_doc_id and isinstance(page_no, int)):
                continue
            key = str(raw_name)
            if key not in first_doc_id:
                # First chunk for this filename: initialise all accumulators.
                first_doc_id[key] = str(raw_doc_id)
                pages_seen[key] = set()
                chunk_totals[key] = 0
            pages_seen[key].add(page_no)
            chunk_totals[key] += 1

    documents: list[dict[str, object]] = []
    for key, doc_id in first_doc_id.items():
        page_numbers = pages_seen[key]
        documents.append(
            {
                "filename": key,
                "document_id": doc_id,
                "pages": sorted(page_numbers),
                "page_count": len(page_numbers),
                "chunk_count": chunk_totals[key],
            }
        )
    documents.sort(key=lambda entry: str(entry["filename"]))
    return documents