Spaces:
Sleeping
Sleeping
import asyncio
import uuid
from typing import List, Optional

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue

from app.config import config, settings
from app.utils.logger import logger
class VectorStore:
    """Thin wrapper around a lazily created synchronous Qdrant client.

    The collection name, server URL and default vector size are read from
    ``config["database"]["qdrant"]``; the API key comes from
    ``settings.qdrant_api_key``.
    """

    def __init__(self):
        # The client is created on first use (see ``connect``/``get_client``),
        # so constructing this object opens no connection.
        self.client = None
        self.collection_name = config["database"]["qdrant"]["collection_name"]

    def connect(self):
        """Create the Qdrant client if it does not exist yet, and return it."""
        if self.client is None:
            qdrant_url = config["database"]["qdrant"]["url"]
            # Normalize an empty/falsy configured key to None ("no key").
            api_key = settings.qdrant_api_key or None
            self.client = QdrantClient(url=qdrant_url, api_key=api_key)
            logger.info("Qdrant connected")
        return self.client

    def create_collection(self, vector_size: Optional[int] = None) -> None:
        """Create the configured collection (cosine distance) if it is missing.

        Args:
            vector_size: Dimensionality of the stored vectors. Defaults to
                ``config["database"]["qdrant"]["vector_size"]``.

        NOTE(review): an existing collection is left as-is; its vector size
        and distance are not checked against ``vector_size``.
        """
        if vector_size is None:
            vector_size = config["database"]["qdrant"]["vector_size"]
        client = self.get_client()
        if client.collection_exists(self.collection_name):
            logger.info(f"Qdrant collection already exists: {self.collection_name}")
            return
        client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
        )
        logger.info(f"Created Qdrant collection: {self.collection_name}")

    def get_client(self):
        """Return the Qdrant client, creating it on first use."""
        return self.connect()

    async def add_documents(
        self,
        collection_name: str,
        documents: List,
        embeddings: List[List[float]],
    ) -> List[str]:
        """Upsert documents with their embeddings into ``collection_name``.

        Each document must expose ``page_content`` and a ``metadata`` mapping
        (presumably LangChain-style ``Document`` objects — confirm against
        callers). Each point gets a fresh random UUID and a payload made of
        the document metadata plus the chunk text under ``"text"``.

        Args:
            collection_name: Target Qdrant collection.
            documents: Documents to store, aligned index-by-index with
                ``embeddings``.
            embeddings: One vector per document.

        Returns:
            The generated point ids, in input order (empty if no documents).

        Raises:
            ValueError: If ``documents`` and ``embeddings`` differ in length.
        """
        documents = list(documents)
        embeddings = list(embeddings)
        # zip() would silently drop the surplus items; fail loudly instead.
        if len(documents) != len(embeddings):
            raise ValueError(
                f"Got {len(documents)} documents but {len(embeddings)} embeddings"
            )
        if not documents:
            # Nothing to write; skip the round-trip to Qdrant.
            return []

        client = self.get_client()
        points = [
            PointStruct(
                id=str(uuid.uuid4()),
                vector=embedding,
                # Metadata is spread first so the chunk text always wins, even
                # when the metadata itself carries a "text" key.
                payload={**doc.metadata, "text": doc.page_content},
            )
            for doc, embedding in zip(documents, embeddings)
        ]
        # QdrantClient.upsert is a blocking call; run it in a worker thread so
        # this coroutine does not stall the event loop.
        await asyncio.to_thread(
            client.upsert,
            collection_name=collection_name,
            points=points,
        )
        logger.info(f"Added {len(points)} documents to Qdrant")
        return [point.id for point in points]

    async def delete_by_metadata(
        self, collection_name: str, metadata_key: str, metadata_value: str
    ) -> None:
        """Delete every point whose payload field ``metadata_key`` equals ``metadata_value``.

        Args:
            collection_name: Collection to delete from.
            metadata_key: Payload field to match on.
            metadata_value: Exact value the field must have.
        """
        client = self.get_client()
        selector = Filter(
            must=[
                FieldCondition(
                    key=metadata_key,
                    match=MatchValue(value=metadata_value),
                )
            ]
        )
        # QdrantClient.delete is a blocking call; keep it off the event loop.
        await asyncio.to_thread(
            client.delete,
            collection_name=collection_name,
            points_selector=selector,
        )
        logger.info(f"Deleted documents with {metadata_key}={metadata_value} from Qdrant")
# Module-level shared instance. Constructing it only reads the collection name
# from config; the Qdrant client itself is created on first use via get_client().
vector_store = VectorStore()