duongtruongbinh's picture
Init project
16fa4e7
"""Embeddings, Qdrant client, collection setup, and vector store."""
from __future__ import annotations
from collections.abc import Iterator
from functools import lru_cache
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http import models as qmodels
from src.config import settings
from src.embeddings import get_embeddings
# Default page size (points per request) used by scroll_all().
_SCROLL_PAGE_SIZE: int = 256
# Payload fields that ensure_collection() indexes, mapped to the index schema it
# expects. ensure_collection() creates any missing index and raises ValueError if
# an existing index reports a different data type.
INDEXED_PAYLOAD_FIELDS: dict[str, qmodels.PayloadSchemaType] = {
    "metadata.document_id": qmodels.PayloadSchemaType.KEYWORD,
    "metadata.filename": qmodels.PayloadSchemaType.KEYWORD,
    # Integer index so page numbers can be range-filtered / matched numerically.
    "metadata.page": qmodels.PayloadSchemaType.INTEGER,
}
def close_client() -> None:
    """Close the cached Qdrant client, if one exists, and drop it from the cache.

    This is a no-op when no client has been created yet. The cache is cleared
    even if ``close()`` raises, so the next ``get_client()`` call builds a fresh
    client instead of returning a half-closed one. Any exception raised by
    ``close()`` still propagates to the caller.
    """
    if get_client.cache_info().currsize == 0:
        # Nothing cached: calling get_client() here would open a client just to close it.
        return
    client = get_client()
    try:
        client.close()
    finally:
        # Always forget the instance, even if close() failed part-way.
        get_client.cache_clear()
@lru_cache(maxsize=1)
def get_client() -> QdrantClient:
    """Build the embedded Qdrant client once and return the same instance afterwards.

    The client runs against on-disk storage at ``settings.storage_dir``. The
    directory, and any missing parents, is created on demand. The instance stays
    cached until the cache is cleared, for example by ``close_client()``.
    """
    storage = settings.storage_dir
    storage.mkdir(exist_ok=True, parents=True)
    return QdrantClient(path=str(storage))
def ensure_collection(recreate: bool = False, collection_name: str | None = None) -> None:
    """Make sure the target collection and its payload indexes are present.

    A missing collection is created with cosine distance. Its vector size is
    taken from the length of one embedding of a probe string. Then every field
    in ``INDEXED_PAYLOAD_FIELDS`` that has no reported payload index gets one
    created.

    Args:
        recreate: When true, an existing collection with this name is deleted
            and rebuilt from scratch.
        collection_name: Collection to manage. An omitted or empty value falls
            back to ``settings.qdrant_collection``.

    Raises:
        ValueError: If a field already has a payload index whose ``data_type``
            differs from the schema listed in ``INDEXED_PAYLOAD_FIELDS``.
    """
    qdrant = get_client()
    target = collection_name or settings.qdrant_collection

    present = qdrant.collection_exists(target)
    if recreate and present:
        qdrant.delete_collection(target)
        present = False

    if not present:
        # Embed a throwaway string to discover the model's output dimensionality.
        vector_size = len(get_embeddings().embed_query("dimension probe"))
        qdrant.create_collection(
            collection_name=target,
            vectors_config=qmodels.VectorParams(
                size=vector_size,
                distance=qmodels.Distance.COSINE,
            ),
        )

    # NOTE(review): with an embedded (path=) client, payload_schema may always come
    # back empty, so indexes would be re-requested on every call. Confirm this
    # against qdrant-client local mode.
    known_indexes = qdrant.get_collection(target).payload_schema or {}
    for field, expected in INDEXED_PAYLOAD_FIELDS.items():
        current = known_indexes.get(field)
        if current is None:
            qdrant.create_payload_index(
                collection_name=target,
                field_name=field,
                field_schema=expected,
            )
        else:
            actual = getattr(current, "data_type", None)
            if actual != expected:
                raise ValueError(
                    f"Payload index for '{field}' has schema "
                    f"{actual!r}, expected {expected!r}."
                )
def scroll_all(
    collection_name: str,
    scroll_filter: qmodels.Filter | None = None,
    with_payload: bool | list[str] = True,
    limit: int = _SCROLL_PAGE_SIZE,
) -> Iterator[list]:
    """Iterate over a collection one scroll page at a time.

    Each yielded item is the list of points returned by a single
    ``client.scroll`` call. Vectors are never requested. Iteration stops once
    Qdrant reports no next-page offset.

    Args:
        collection_name: Collection to read.
        scroll_filter: Optional Qdrant filter applied to every page.
        with_payload: ``True`` for full payloads, or a list of payload keys to return.
        limit: Maximum number of points per page.

    A ``ValueError`` whose message contains "not found" ends the iteration
    silently; the local client uses this to signal a missing collection. Any
    other ``ValueError`` propagates.
    """
    qdrant = get_client()
    cursor = None
    while True:
        try:
            page, cursor = qdrant.scroll(
                collection_name=collection_name,
                scroll_filter=scroll_filter,
                limit=limit,
                offset=cursor,
                with_payload=with_payload,
                with_vectors=False,
            )
        except ValueError as err:
            # Local Qdrant reports a not-yet-created collection via ValueError.
            if "not found" not in str(err).lower():
                raise
            return
        yield page
        if cursor is None:
            return
def get_vector_store(collection_name: str | None = None) -> QdrantVectorStore:
    """Wrap the shared Qdrant client and embedding model in a LangChain vector store.

    Args:
        collection_name: Target collection. An omitted or empty value falls back
            to ``settings.qdrant_collection``.
    """
    qdrant = get_client()
    target = collection_name if collection_name else settings.qdrant_collection
    embedder = get_embeddings()
    return QdrantVectorStore(client=qdrant, collection_name=target, embedding=embedder)
def list_documents(collection_name: str | None = None) -> list[dict[str, object]]:
    """List indexed documents with filename, document_id, pages, and chunk counts.

    Returns one entry per filename matching the API `DocumentInfo` shape,
    sorted by filename. Each entry has the keys ``filename``, ``document_id``,
    ``pages`` (sorted), ``page_count`` and ``chunk_count``.

    Args:
        collection_name: Collection to scan. An omitted or empty value falls back
            to ``settings.qdrant_collection``, the same convention used by
            ``ensure_collection`` and ``get_vector_store``.

    Points are skipped, and do not count toward ``chunk_count``, when their
    ``metadata`` is not a mapping or lacks a filename, a document_id, or an
    integer page. If several document_ids share one filename, the first one seen
    during the scroll is reported and all their chunks are merged under that
    filename.
    """
    target = collection_name or settings.qdrant_collection
    pages_by_file: dict[str, set[int]] = {}
    doc_id_by_file: dict[str, str] = {}
    chunks_by_file: dict[str, int] = {}
    for batch in scroll_all(target, with_payload=["metadata"]):
        for point in batch:
            meta = (point.payload or {}).get("metadata") or {}
            if not isinstance(meta, dict):
                # Malformed payload; .get() below would raise AttributeError.
                continue
            filename = meta.get("filename")
            document_id = meta.get("document_id")
            page = meta.get("page")
            # bool is a subclass of int, but True/False is not a page number.
            if (
                not filename
                or not document_id
                or not isinstance(page, int)
                or isinstance(page, bool)
            ):
                continue
            fn = str(filename)
            doc_id_by_file.setdefault(fn, str(document_id))
            pages_by_file.setdefault(fn, set()).add(page)
            chunks_by_file[fn] = chunks_by_file.get(fn, 0) + 1
    return [
        {
            "filename": fn,
            "document_id": doc_id_by_file[fn],
            "pages": sorted(pages_by_file[fn]),
            "page_count": len(pages_by_file[fn]),
            "chunk_count": chunks_by_file[fn],
        }
        # Keys are already str, so sorting them orders entries by filename.
        for fn in sorted(doc_id_by_file)
    ]