Spaces:
Sleeping
Sleeping
import logging

from qdrant_client import QdrantClient
from qdrant_client.http.models import (
    DatetimeRange,
    Distance,
    FieldCondition,
    Filter,
    MatchAny,
    MatchValue,
    PayloadSchemaType,
    PointIdsList,
    PointStruct,
    Range,
    VectorParams,
)

from app.config import get_settings
from app.models.document import Chunk
from app.models.schemas import SearchFilters
| logger = logging.getLogger(__name__) | |
class VectorStoreService:
    """Wrapper around one Qdrant collection that stores document chunks.

    Each point's payload carries the chunk text plus document metadata
    (source, doc_type, title, created_date, tags, page_count) so that
    searches can be filtered on those fields.
    """

    def __init__(self, url: str, api_key: str, collection_name: str):
        """Create the Qdrant client and remember the target collection.

        Args:
            url: Qdrant endpoint URL.
            api_key: API key passed to the Qdrant client.
            collection_name: Collection every method of this service operates on.
        """
        self.client = QdrantClient(url=url, api_key=api_key)
        self.collection_name = collection_name
        logger.info(f"Connected to Qdrant at {url}")

    def ensure_collection(self, vector_size: int = 384) -> None:
        """Create the collection if it is missing, then ensure payload indexes.

        Args:
            vector_size: Dimensionality used when the collection is created.
                Ignored when the collection already exists.
        """
        collections = [c.name for c in self.client.get_collections().collections]
        if self.collection_name not in collections:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
            )
            logger.info(f"Created collection '{self.collection_name}' (dim={vector_size})")
        else:
            logger.info(f"Collection '{self.collection_name}' already exists")
        # Ensure payload indexes exist for filterable fields
        self._ensure_payload_indexes()

    def _ensure_payload_indexes(self) -> None:
        """Create payload indexes for fields used in filtering.

        Index creation is best-effort: a failure for one field is logged and
        the remaining fields are still attempted.
        """
        index_fields = {
            "document_id": PayloadSchemaType.KEYWORD,
            "source": PayloadSchemaType.KEYWORD,
            "doc_type": PayloadSchemaType.KEYWORD,
            "tags": PayloadSchemaType.KEYWORD,
            # created_date is filtered with a date range in _build_filter, which a
            # keyword index cannot serve; it needs a datetime index.
            # NOTE: a collection that already has an index on this field is left
            # as is, because only missing field names are created below.
            "created_date": PayloadSchemaType.DATETIME,
        }
        try:
            collection_info = self.client.get_collection(self.collection_name)
            existing_indexes = set(collection_info.payload_schema or {})
        except Exception as e:
            # Fall back to attempting every index; each attempt is guarded below.
            logger.warning(
                f"Could not read payload schema for '{self.collection_name}': {e}"
            )
            existing_indexes = set()
        for field_name, field_type in index_fields.items():
            if field_name in existing_indexes:
                continue
            try:
                self.client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=field_type,
                )
                logger.info(f"Created payload index: {field_name} ({field_type})")
            except Exception as e:
                logger.warning(f"Could not create index for '{field_name}': {e}")

    def upsert_chunks(self, chunks: list[Chunk], embeddings: list[list[float]]) -> None:
        """Upsert chunks and their embeddings in batches of 100 points.

        Args:
            chunks: Chunks to store; ``chunk.chunk_id`` becomes the point id.
            embeddings: One vector per chunk, in the same order as ``chunks``.

        Raises:
            ValueError: If ``chunks`` and ``embeddings`` differ in length.
        """
        # zip() would silently drop the unmatched tail, so reject a mismatch
        # before anything is written.
        if len(chunks) != len(embeddings):
            raise ValueError(
                f"Got {len(chunks)} chunks but {len(embeddings)} embeddings"
            )
        batch_size = 100
        for i in range(0, len(chunks), batch_size):
            batch_chunks = chunks[i : i + batch_size]
            batch_embeddings = embeddings[i : i + batch_size]
            points = [
                PointStruct(
                    id=chunk.chunk_id,
                    vector=embedding,
                    payload={
                        "text": chunk.text,
                        "document_id": chunk.document_id,
                        "chunk_index": chunk.chunk_index,
                        "source": chunk.metadata.source,
                        "doc_type": chunk.metadata.doc_type,
                        "title": chunk.metadata.title,
                        "created_date": chunk.metadata.created_date.isoformat()
                        if chunk.metadata.created_date
                        else None,
                        "tags": chunk.metadata.tags,
                        "page_count": chunk.metadata.page_count,
                    },
                )
                for chunk, embedding in zip(batch_chunks, batch_embeddings)
            ]
            self.client.upsert(collection_name=self.collection_name, points=points)
        logger.info(f"Upserted {len(chunks)} chunks to '{self.collection_name}'")

    def search(
        self,
        query_vector: list[float],
        limit: int = 10,
        filters: SearchFilters | None = None,
    ) -> list[dict]:
        """Run a vector search and return the hits as plain dictionaries.

        Args:
            query_vector: Embedding of the query.
            limit: Maximum number of hits to return.
            filters: Optional metadata filters; ignored when ``None`` or when
                ``filters.has_filters()`` is false.

        Returns:
            One dict per hit with ``chunk_id``, ``text``, ``score``,
            ``document_id`` and a nested ``metadata`` dict.
        """
        qdrant_filter = self._build_filter(filters) if filters and filters.has_filters() else None
        results = self.client.query_points(
            collection_name=self.collection_name,
            query=query_vector,
            limit=limit,
            query_filter=qdrant_filter,
        ).points
        hits = []
        for r in results:
            # A point may come back without a payload; treat that as empty.
            payload = r.payload or {}
            hits.append(
                {
                    "chunk_id": str(r.id),
                    "text": payload.get("text", ""),
                    "score": r.score,
                    "document_id": payload.get("document_id", ""),
                    "metadata": {
                        "source": payload.get("source", ""),
                        "doc_type": payload.get("doc_type", ""),
                        "title": payload.get("title"),
                        "created_date": payload.get("created_date"),
                        "tags": payload.get("tags", []),
                        "page_count": payload.get("page_count"),
                    },
                }
            )
        return hits

    def delete_document(self, document_id: str) -> int:
        """Delete every point whose payload ``document_id`` matches.

        Args:
            document_id: Identifier of the document to remove.

        Returns:
            Number of points deleted; 0 when no point matched.
        """
        # First, find all point IDs belonging to this document
        doc_filter = Filter(
            must=[FieldCondition(key="document_id", match=MatchValue(value=document_id))]
        )
        point_ids = []
        offset = None
        while True:
            results, next_offset = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=doc_filter,
                limit=100,
                offset=offset,
                with_payload=False,
                with_vectors=False,
            )
            point_ids.extend(r.id for r in results)
            if next_offset is None:
                break
            offset = next_offset
        if not point_ids:
            logger.warning(f"No points found for document '{document_id}'")
            return 0
        # Delete by point IDs (requires only write permission, not manage)
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=PointIdsList(points=point_ids),
        )
        logger.info(f"Deleted {len(point_ids)} points for document '{document_id}'")
        return len(point_ids)

    def scroll_all(self, batch_size: int = 100) -> list[dict]:
        """Page through the whole collection and return every point.

        Args:
            batch_size: Number of points requested per scroll call.

        Returns:
            One dict per point with ``chunk_id``, ``text``, ``document_id`` and
            a ``metadata`` dict (source, doc_type, title, tags). Vectors are
            not fetched.
        """
        all_points = []
        offset = None
        while True:
            results, next_offset = self.client.scroll(
                collection_name=self.collection_name,
                limit=batch_size,
                offset=offset,
                with_payload=True,
                with_vectors=False,
            )
            for r in results:
                # A point may come back without a payload; treat that as empty.
                payload = r.payload or {}
                all_points.append({
                    "chunk_id": str(r.id),
                    "text": payload.get("text", ""),
                    "document_id": payload.get("document_id", ""),
                    "metadata": {
                        "source": payload.get("source", ""),
                        "doc_type": payload.get("doc_type", ""),
                        "title": payload.get("title"),
                        "tags": payload.get("tags", []),
                    },
                })
            if next_offset is None:
                break
            offset = next_offset
        return all_points

    def get_document_ids(self) -> list[dict]:
        """Summarise the collection as one entry per distinct document.

        Returns:
            Dicts with ``document_id``, ``source``, ``title``, ``doc_type`` and
            ``num_chunks``. The descriptive fields come from the first chunk
            seen for each document.
        """
        all_points = self.scroll_all()
        docs: dict[str, dict] = {}
        for p in all_points:
            doc_id = p["document_id"]
            if doc_id not in docs:
                docs[doc_id] = {
                    "document_id": doc_id,
                    "source": p["metadata"]["source"],
                    "title": p["metadata"].get("title"),
                    "doc_type": p["metadata"]["doc_type"],
                    "num_chunks": 0,
                }
            docs[doc_id]["num_chunks"] += 1
        return list(docs.values())

    def count(self) -> int:
        """Return the collection's point count, or 0 when none is reported."""
        info = self.client.get_collection(self.collection_name)
        # points_count can be None; keep the declared int return type.
        return info.points_count or 0

    @staticmethod
    def _build_filter(filters: SearchFilters) -> Filter | None:
        """Translate ``SearchFilters`` into a Qdrant ``Filter``.

        Declared static because it takes no ``self`` yet is invoked as
        ``self._build_filter(filters)`` from ``search``.

        Args:
            filters: Search filters; only truthy fields produce conditions.

        Returns:
            A ``Filter`` whose conditions must all match, or ``None`` when no
            field is set.
        """
        conditions = []
        if filters.source:
            conditions.append(FieldCondition(key="source", match=MatchValue(value=filters.source)))
        if filters.doc_type:
            conditions.append(FieldCondition(key="doc_type", match=MatchValue(value=filters.doc_type)))
        if filters.tags:
            conditions.append(FieldCondition(key="tags", match=MatchAny(any=filters.tags)))
        if filters.date_from or filters.date_to:
            range_params = {}
            if filters.date_from:
                range_params["gte"] = filters.date_from
            if filters.date_to:
                range_params["lte"] = filters.date_to
            # Range only accepts numeric bounds; date bounds need DatetimeRange.
            conditions.append(
                FieldCondition(key="created_date", range=DatetimeRange(**range_params))
            )
        return Filter(must=conditions) if conditions else None
# Process-wide VectorStoreService instance, created on first use.
_vectorstore: VectorStoreService | None = None


def get_vectorstore() -> VectorStoreService:
    """Return the shared VectorStoreService, creating it on first call.

    The service is built from ``get_settings()`` (qdrant_url, qdrant_api_key,
    qdrant_collection) and its collection is ensured with
    ``settings.embedding_dim`` before it is cached.

    Returns:
        The cached service instance.

    Raises:
        Exception: Whatever construction or ``ensure_collection`` raises. In
            that case nothing is cached, so the next call retries.
    """
    global _vectorstore
    if _vectorstore is None:
        settings = get_settings()
        store = VectorStoreService(
            url=settings.qdrant_url,
            api_key=settings.qdrant_api_key,
            collection_name=settings.qdrant_collection,
        )
        store.ensure_collection(vector_size=settings.embedding_dim)
        # Cache only after setup succeeded, so a failure above cannot leave a
        # half-initialised instance behind for later callers.
        _vectorstore = store
    return _vectorstore