File size: 4,923 Bytes
"""Embeddings, Qdrant client, collection setup, and vector store."""
from __future__ import annotations
from collections.abc import Iterator
from functools import lru_cache
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http import models as qmodels
from src.config import settings
from src.embeddings import get_embeddings
# Default number of points requested per page by `scroll_all`.
_SCROLL_PAGE_SIZE = 256

# Payload fields that `ensure_collection` indexes, mapped to the schema type
# each index is expected to have. Keys are dotted paths into the point payload.
INDEXED_PAYLOAD_FIELDS = {
    "metadata.document_id": qmodels.PayloadSchemaType.KEYWORD,
    "metadata.filename": qmodels.PayloadSchemaType.KEYWORD,
    "metadata.page": qmodels.PayloadSchemaType.INTEGER,
}
def close_client() -> None:
    """Close the cached Qdrant client, if one exists, and reset the cache.

    Does nothing when ``get_client`` currently holds no cached instance, so
    calling this never creates a client merely to close it.

    The cache is cleared even when ``close()`` raises; otherwise a later
    ``get_client()`` call would keep returning the same client whose close
    attempt failed. Any exception from ``close()`` still propagates.
    """
    if get_client.cache_info().currsize == 0:
        return
    # The cache is non-empty, so this returns the existing instance.
    client = get_client()
    try:
        client.close()
    finally:
        get_client.cache_clear()
@lru_cache(maxsize=1)
def get_client() -> QdrantClient:
    """Build the Qdrant client on first call and reuse it on later calls.

    The client is opened against the on-disk path ``settings.storage_dir``;
    that directory (and any missing parents) is created beforehand.
    """
    storage_path = settings.storage_dir
    storage_path.mkdir(parents=True, exist_ok=True)
    return QdrantClient(path=str(storage_path))
def ensure_collection(recreate: bool = False, collection_name: str | None = None) -> None:
    """Make sure the collection and its payload indexes are present.

    Args:
        recreate: When true and the collection already exists, delete it and
            create it again.
        collection_name: Collection to operate on; falls back to
            ``settings.qdrant_collection`` when empty or ``None``.

    Raises:
        ValueError: If a field listed in ``INDEXED_PAYLOAD_FIELDS`` already
            has a payload index whose ``data_type`` differs from the expected
            schema type.
    """
    qdrant = get_client()
    target = collection_name or settings.qdrant_collection

    missing = not qdrant.collection_exists(target)
    if recreate and not missing:
        qdrant.delete_collection(target)
        missing = True

    if missing:
        # Embed a throwaway string to learn the vector size the model produces.
        probe_vector = get_embeddings().embed_query("dimension probe")
        vector_params = qmodels.VectorParams(
            size=len(probe_vector),
            distance=qmodels.Distance.COSINE,
        )
        qdrant.create_collection(
            collection_name=target,
            vectors_config=vector_params,
        )

    current_indexes = qdrant.get_collection(target).payload_schema or {}
    for field_name, field_schema in INDEXED_PAYLOAD_FIELDS.items():
        existing = current_indexes.get(field_name)
        if existing is not None:
            # An index is already there: it must match the expected type.
            existing_schema = getattr(existing, "data_type", None)
            if existing_schema != field_schema:
                raise ValueError(
                    f"Payload index for '{field_name}' has schema "
                    f"{existing_schema!r}, expected {field_schema!r}."
                )
        else:
            qdrant.create_payload_index(
                collection_name=target,
                field_name=field_name,
                field_schema=field_schema,
            )
def scroll_all(
    collection_name: str,
    scroll_filter: qmodels.Filter | None = None,
    with_payload: bool | list[str] = True,
    limit: int = _SCROLL_PAGE_SIZE,
) -> Iterator[list]:
    """Iterate over a collection page by page, yielding each page of points.

    Vectors are never requested. Iteration stops once the client reports no
    further offset. If the client raises a ``ValueError`` whose message
    contains "not found", the generator ends quietly instead of propagating
    the error; any other ``ValueError`` is re-raised.

    Args:
        collection_name: Collection to scroll through.
        scroll_filter: Optional filter forwarded to the client.
        with_payload: Payload selector forwarded to the client.
        limit: Maximum number of points requested per page.
    """
    qdrant = get_client()
    cursor = None
    has_more = True
    while has_more:
        try:
            page, cursor = qdrant.scroll(
                collection_name=collection_name,
                scroll_filter=scroll_filter,
                limit=limit,
                offset=cursor,
                with_payload=with_payload,
                with_vectors=False,
            )
        except ValueError as err:
            # Local Qdrant raises ValueError when collection doesn't exist yet.
            if "not found" not in str(err).lower():
                raise
            return
        yield page
        has_more = cursor is not None
def get_vector_store(collection_name: str | None = None) -> QdrantVectorStore:
    """Return a ``QdrantVectorStore`` wired to the cached client and embeddings.

    Uses ``collection_name`` when given, otherwise ``settings.qdrant_collection``.
    """
    qdrant = get_client()
    target = collection_name or settings.qdrant_collection
    embedder = get_embeddings()
    return QdrantVectorStore(
        client=qdrant,
        collection_name=target,
        embedding=embedder,
    )
def list_documents() -> list[dict[str, object]]:
    """Summarise the indexed documents, one entry per filename.

    Scrolls ``settings.qdrant_collection`` reading only the ``metadata``
    payload. Points lacking a filename, a document id, or an integer page are
    ignored. For each filename the first document id encountered is kept.

    Returns:
        Entries sorted by filename, each with ``filename``, ``document_id``,
        ``pages`` (sorted distinct page numbers), ``page_count`` and
        ``chunk_count`` (number of points counted for that filename).
    """
    first_ids: dict[str, str] = {}
    page_sets: dict[str, set[int]] = {}
    chunk_totals: dict[str, int] = {}

    batches = scroll_all(settings.qdrant_collection, with_payload=["metadata"])
    for point in (item for batch in batches for item in batch):
        metadata = (point.payload or {}).get("metadata") or {}
        raw_name = metadata.get("filename")
        raw_id = metadata.get("document_id")
        page_no = metadata.get("page")
        if not (raw_name and raw_id and isinstance(page_no, int)):
            continue

        key = str(raw_name)
        if key not in first_ids:
            # First chunk seen for this filename: remember its document id.
            first_ids[key] = str(raw_id)
            page_sets[key] = set()
            chunk_totals[key] = 0
        page_sets[key].add(page_no)
        chunk_totals[key] += 1

    entries: list[dict[str, object]] = []
    for key, doc_id in first_ids.items():
        seen_pages = page_sets[key]
        entries.append(
            {
                "filename": key,
                "document_id": doc_id,
                "pages": sorted(seen_pages),
                "page_count": len(seen_pages),
                "chunk_count": chunk_totals[key],
            }
        )
    entries.sort(key=lambda entry: str(entry["filename"]))
    return entries
|