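# Repository-viewer residue (not part of the module), kept for provenance:
# author: DenysKovalML | commit message: "fix: qdrant deploy" | commit: 0f95a58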
from typing import Any
from loguru import logger
from qdrant_client import QdrantClient as SyncQdrantClient
from qdrant_client.models import (
Distance,
FieldCondition,
Filter,
MatchValue,
Modifier,
PointStruct,
SparseIndexParams,
SparseVector,
SparseVectorParams,
VectorParams,
)
from scientific_rag.domain.documents import PaperChunk
from scientific_rag.domain.queries import QueryFilters
from scientific_rag.settings import settings
class QdrantService:
    """Wrapper around a synchronous Qdrant client for storing and searching paper chunks.

    The server URL, API key and collection name are read from the project
    ``settings`` object at construction time. Each point carries a named
    ``"dense"`` vector and, optionally, a named ``"bm25"`` sparse vector.
    """

    def __init__(self) -> None:
        """Build the Qdrant client from ``settings``.

        When the configured URL is the literal ``":memory:"`` the client is
        created with ``location=":memory:"``; otherwise the URL and API key are
        passed through. Both variants use a 30 second timeout.
        """
        self.url = settings.qdrant_url
        self.api_key = settings.qdrant_api_key
        self.collection_name = settings.qdrant_collection_name

        logger.info(f"Initializing Qdrant client: {self.url}")

        if self.url == ":memory:":
            self.client = SyncQdrantClient(location=":memory:", timeout=30)
        else:
            self.client = SyncQdrantClient(
                url=self.url, api_key=self.api_key, timeout=30
            )

    def create_collection(
        self, vector_size: int = 384, distance: Distance = Distance.COSINE
    ) -> None:
        """Create the collection and its payload indexes if it does not exist yet.

        Args:
            vector_size: Dimensionality of the ``"dense"`` vector.
            distance: Distance metric used for the ``"dense"`` vector.

        The call is a no-op (apart from a log line) when the collection
        already exists.
        """
        if self.client.collection_exists(self.collection_name):
            logger.info(f"Collection '{self.collection_name}' already exists")
            return

        logger.info(
            f"Creating collection '{self.collection_name}' with dense and sparse vectors"
        )
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config={
                "dense": VectorParams(size=vector_size, distance=distance),
            },
            sparse_vectors_config={
                "bm25": SparseVectorParams(
                    index=SparseIndexParams(
                        on_disk=True,
                    ),
                    modifier=Modifier.IDF,
                )
            },
        )

        # Keyword indexes on the payload fields that are created here so that
        # filtered queries can use them.
        for field in ["source", "section", "paper_id"]:
            self.client.create_payload_index(
                collection_name=self.collection_name,
                field_name=field,
                field_schema="keyword",
            )
        logger.info(f"Collection '{self.collection_name}' created with indexes")

    def upsert_chunks(
        self, chunks: list[PaperChunk], sparse_embeddings: list[Any] | None = None
    ) -> int:
        """Upload chunks (and optional sparse vectors) to the collection.

        Args:
            chunks: Chunks to upload. Chunks whose ``embedding`` is ``None``
                are skipped.
            sparse_embeddings: Optional sparse embeddings aligned by position
                with ``chunks``. Each item must expose ``indices`` and
                ``values`` attributes that support ``.tolist()``.

        Returns:
            The number of points actually uploaded.
        """
        if not chunks:
            return 0

        # Sparse vectors are matched to chunks by index, so a length mismatch
        # means some chunks are stored without a "bm25" vector.
        if sparse_embeddings and len(sparse_embeddings) != len(chunks):
            logger.warning(
                f"Got {len(sparse_embeddings)} sparse embeddings for {len(chunks)} chunks; "
                "chunks without a matching sparse embedding are stored with the dense vector only"
            )

        points = []
        for i, chunk in enumerate(chunks):
            if chunk.embedding is None:
                continue

            vectors: dict[str, Any] = {"dense": chunk.embedding}
            if sparse_embeddings and i < len(sparse_embeddings):
                sparse = sparse_embeddings[i]
                vectors["bm25"] = SparseVector(
                    indices=sparse.indices.tolist(), values=sparse.values.tolist()
                )

            points.append(
                PointStruct(
                    id=chunk.chunk_id,
                    vector=vectors,
                    payload=chunk.to_dict(),
                )
            )

        # Nothing to send: avoid a pointless round trip to the server.
        if not points:
            logger.warning("No chunks with embeddings to upload; skipping Qdrant upload")
            return 0

        self.client.upload_points(
            collection_name=self.collection_name,
            points=points,
            wait=True,
        )
        logger.info(f"Uploaded {len(points)} chunks to Qdrant")
        return len(points)

    def search_dense(
        self,
        query_vector: list[float],
        limit: int = 10,
        filters: QueryFilters | None = None,
    ) -> list[PaperChunk]:
        """Standard semantic search using dense vectors.

        Args:
            query_vector: Query embedding searched against the ``"dense"`` vector.
            limit: Maximum number of hits requested.
            filters: Optional payload filters.

        Returns:
            Matching chunks with ``score`` set and ``embedding`` set to ``None``.
        """
        return self._execute_search(
            vector_name="dense",
            vector_data=query_vector,
            is_sparse=False,
            limit=limit,
            filters=filters,
        )

    def search_sparse(
        self,
        query_sparse_indices: list[int],
        query_sparse_values: list[float],
        limit: int = 10,
        filters: QueryFilters | None = None,
    ) -> list[PaperChunk]:
        """BM25-style search using sparse vectors.

        Args:
            query_sparse_indices: Indices of the non-zero query terms.
            query_sparse_values: Weights paired by position with the indices.
            limit: Maximum number of hits requested.
            filters: Optional payload filters.

        Returns:
            Matching chunks with ``score`` set and ``embedding`` set to ``None``.
        """
        sparse_vec = SparseVector(
            indices=query_sparse_indices,
            values=query_sparse_values,
        )
        return self._execute_search(
            vector_name="bm25",
            vector_data=sparse_vec,
            is_sparse=True,
            limit=limit,
            filters=filters,
        )

    def _execute_search(
        self,
        vector_name: str,
        vector_data: Any,
        is_sparse: bool,
        limit: int,
        filters: QueryFilters | None,
    ) -> list[PaperChunk]:
        """Run a query against one named vector and convert hits to ``PaperChunk``.

        Args:
            vector_name: Name of the vector to search (``"dense"`` or ``"bm25"``).
            vector_data: Query passed to the client as-is.
            is_sparse: Currently unused; kept so existing callers keep working.
            limit: Maximum number of hits requested.
            filters: Optional payload filters.
        """
        query_filter = self._build_filters(filters) if filters else None

        results = self.client.query_points(
            collection_name=self.collection_name,
            using=vector_name,
            query=vector_data,
            query_filter=query_filter,
            limit=limit,
            with_payload=True,
        )

        # Merge into a single mapping so the search-time score/embedding
        # override any "score"/"embedding" keys already stored in the payload
        # instead of raising "got multiple values for keyword argument".
        return [
            PaperChunk(**{**hit.payload, "score": hit.score, "embedding": None})
            for hit in results.points
        ]

    def _build_filters(self, filters: QueryFilters) -> Filter | None:
        """Translate ``QueryFilters`` into a Qdrant ``Filter``.

        The dict returned by ``filters.to_qdrant_filter()`` is read either as
        ``{"must": [...]}`` or as a single condition. Only entries that have
        both a ``"key"`` and a ``"match"`` entry are converted; each becomes an
        exact-value ``must`` condition.

        Returns:
            A ``Filter`` with the converted conditions, or ``None`` when there
            is nothing to filter on.
        """
        filter_dict = filters.to_qdrant_filter()
        if not filter_dict:
            return None

        target_list = (
            filter_dict.get("must", []) if "must" in filter_dict else [filter_dict]
        )
        must_conditions = [
            FieldCondition(
                key=item["key"], match=MatchValue(value=item["match"]["value"])
            )
            for item in target_list
            if "key" in item and "match" in item
        ]
        return Filter(must=must_conditions) if must_conditions else None

    def get_collection_info(self) -> dict[str, Any]:
        """Report whether the collection exists and, if so, its point count.

        Returns:
            ``{"exists": False}`` when the collection is missing, otherwise
            ``{"exists": True, "points_count": <count>}``.
        """
        if not self.client.collection_exists(self.collection_name):
            return {"exists": False}
        info = self.client.get_collection(self.collection_name)
        return {"exists": True, "points_count": info.points_count}

    def close(self) -> None:
        """Close the underlying Qdrant client."""
        self.client.close()
        logger.info("Qdrant client closed")
# Module-level shared instance. Constructing it here runs QdrantService.__init__,
# so the Qdrant client is created as soon as this module is imported.
qdrant_service = QdrantService()