Spaces:
Sleeping
Sleeping
| """Qdrant vector database adapter.""" | |
| from app.ports.vector_db import VectorDBPort, VectorChunk, VectorSearchResult | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue | |
| from typing import List | |
| from app.config import get_settings | |
| import logging | |
# Logger for this module, named after the module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
# Application settings, resolved once when this module is imported.
# QdrantAdapter reads its QDRANT_* connection values from this object.
settings = get_settings()
class QdrantAdapter(VectorDBPort):
    """Qdrant implementation of VectorDBPort.

    Wraps a synchronous ``QdrantClient`` built from the module-level
    ``settings``. If building the client fails, ``self.client`` is left as
    ``None``. In that case ``initialize_collection`` logs a warning and does
    nothing, while ``store_chunks``, ``search`` and ``delete_document`` raise
    ``RuntimeError``.

    NOTE(review): the ``async`` methods call the blocking client directly, so
    each request blocks the event loop while it runs. Consider
    ``AsyncQdrantClient`` or offloading to a worker thread if that matters.
    """

    # Payload keys written by ``store_chunks`` and read back by ``search``.
    # They are excluded from ``VectorSearchResult.metadata``.
    _RESERVED_PAYLOAD_KEYS = frozenset({"document_id", "chunk_index", "text"})

    def __init__(self):
        """Create the Qdrant client from ``settings``.

        When ``QDRANT_API_KEY`` is set, connect to an HTTPS URL with that key
        (Qdrant Cloud). Otherwise use plain host/port (local instance).
        Any exception during construction is logged as a warning and leaves
        ``self.client`` as ``None``.

        NOTE(review): depending on the qdrant-client version, constructing the
        client may not contact the server. The "Connected" log line may
        therefore not prove reachability; confirm.
        """
        try:
            if settings.QDRANT_API_KEY:
                # Qdrant Cloud: authenticate with the API key over HTTPS.
                self.client = QdrantClient(
                    url=f"https://{settings.QDRANT_HOST}:{settings.QDRANT_PORT}",
                    api_key=settings.QDRANT_API_KEY,
                    timeout=10.0,
                )
                logger.info(f"Connected to Qdrant Cloud at {settings.QDRANT_HOST}")
            else:
                # Local Qdrant instance.
                self.client = QdrantClient(
                    host=settings.QDRANT_HOST,
                    port=settings.QDRANT_PORT,
                    timeout=5.0,
                )
                logger.info(f"Connected to Qdrant at {settings.QDRANT_HOST}:{settings.QDRANT_PORT}")
        except Exception as e:
            # Best-effort: allow the application to start without Qdrant.
            # Later calls check ``self.client`` before using it.
            logger.warning(f"Failed to connect to Qdrant: {e}")
            self.client = None

    def _require_client(self):
        """Return the client, or raise RuntimeError if __init__ could not create one."""
        if self.client is None:
            raise RuntimeError("Qdrant client not available")
        return self.client

    async def initialize_collection(self, collection_name: str, dimension: int) -> None:
        """Create ``collection_name`` with cosine distance if it does not exist yet.

        Args:
            collection_name: Name of the Qdrant collection.
            dimension: Vector size used when the collection is created.

        Skips (with a warning) when the client is unavailable. Client errors
        are logged and re-raised.
        """
        if self.client is None:
            logger.warning("Qdrant client not available - skipping collection initialization")
            return
        try:
            collections = self.client.get_collections().collections
            exists = any(c.name == collection_name for c in collections)
            if not exists:
                self.client.create_collection(
                    collection_name=collection_name,
                    vectors_config=VectorParams(
                        size=dimension,
                        distance=Distance.COSINE,
                    ),
                )
                logger.info(f"Created collection: {collection_name}")
            else:
                logger.info(f"Collection already exists: {collection_name}")
        except Exception as e:
            logger.error(f"Error initializing collection: {e}")
            raise

    async def store_chunks(self, chunks: List[VectorChunk], collection_name: str) -> None:
        """Upsert ``chunks`` into ``collection_name``, one point per chunk.

        Each point's id is ``chunk.id`` and its vector is ``chunk.embedding``.
        The payload is the chunk's metadata plus ``document_id``,
        ``chunk_index`` and ``text``. The three reserved keys take precedence
        over same-named metadata keys, because ``search`` and
        ``delete_document`` rely on them.

        NOTE(review): ``search`` filters on an ``org_id`` payload key. This
        method only writes one if it is present in ``chunk.metadata``; confirm
        that callers always supply it.

        Raises:
            RuntimeError: if the Qdrant client is unavailable (chunks non-empty).
            Exception: client errors are logged and re-raised.
        """
        if not chunks:
            # Nothing to write; skip the network round trip.
            return
        try:
            client = self._require_client()
            points = [
                PointStruct(
                    id=chunk.id,
                    vector=chunk.embedding,
                    payload={
                        # Metadata goes first so it cannot overwrite the
                        # reserved keys below; tolerate missing metadata.
                        **(chunk.metadata or {}),
                        "document_id": chunk.document_id,
                        "chunk_index": chunk.chunk_index,
                        "text": chunk.text,
                    },
                )
                for chunk in chunks
            ]
            client.upsert(
                collection_name=collection_name,
                points=points,
            )
            logger.info(f"Stored {len(chunks)} chunks in {collection_name}")
        except Exception as e:
            logger.error(f"Error storing chunks: {e}")
            raise

    async def search(
        self,
        query_embedding: List[float],
        org_id: str,
        top_k: int,
        collection_name: str
    ) -> List[VectorSearchResult]:
        """Return up to ``top_k`` nearest chunks whose payload ``org_id`` equals ``org_id``.

        Each hit's reserved payload keys become ``VectorSearchResult`` fields.
        All remaining payload keys are returned as ``metadata``.

        NOTE(review): newer qdrant-client releases deprecate ``search`` in
        favour of ``query_points``; check against the pinned client version.

        Raises:
            RuntimeError: if the Qdrant client is unavailable.
            Exception: client errors are logged and re-raised.
        """
        try:
            client = self._require_client()
            hits = client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                query_filter=Filter(
                    must=[
                        FieldCondition(
                            key="org_id",
                            match=MatchValue(value=org_id),
                        )
                    ]
                ),
                limit=top_k,
            )
            search_results = [
                VectorSearchResult(
                    document_id=hit.payload["document_id"],
                    chunk_index=hit.payload["chunk_index"],
                    text=hit.payload["text"],
                    score=hit.score,
                    metadata={
                        k: v
                        for k, v in hit.payload.items()
                        if k not in self._RESERVED_PAYLOAD_KEYS
                    },
                )
                for hit in hits
            ]
            logger.info(f"Found {len(search_results)} results for org {org_id}")
            return search_results
        except Exception as e:
            logger.error(f"Error searching vectors: {e}")
            raise

    async def delete_document(self, document_id: str, collection_name: str) -> None:
        """Delete every point whose payload ``document_id`` equals ``document_id``.

        Raises:
            RuntimeError: if the Qdrant client is unavailable.
            Exception: client errors are logged and re-raised.
        """
        try:
            client = self._require_client()
            client.delete(
                collection_name=collection_name,
                points_selector=Filter(
                    must=[
                        FieldCondition(
                            key="document_id",
                            match=MatchValue(value=document_id),
                        )
                    ]
                ),
            )
            logger.info(f"Deleted chunks for document {document_id}")
        except Exception as e:
            logger.error(f"Error deleting document chunks: {e}")
            raise