import logging
import os
import sys
import uuid
from typing import List, Dict, Any, Optional

from qdrant_client import QdrantClient
from qdrant_client.http import models

# Add the current directory to the path so we can import config.
# This must run before the `config` import below.
sys.path.insert(0, os.path.dirname(__file__))
from config import QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME

logger = logging.getLogger(__name__)


class VectorStore:
    """
    A class to handle vector storage and retrieval using Qdrant.

    Every method operates on the single collection named by
    ``COLLECTION_NAME`` from ``config``.
    """

    # Number of points requested per scroll page when collecting IDs.
    _SCROLL_PAGE_SIZE = 10000

    def __init__(self):
        """Create the Qdrant client from the values in ``config``.

        When ``QDRANT_API_KEY`` is truthy the client is built with the API
        key and ``prefer_grpc=True``; otherwise only the URL is passed.
        """
        if QDRANT_API_KEY:
            self.client = QdrantClient(
                url=QDRANT_URL,
                api_key=QDRANT_API_KEY,
                prefer_grpc=True
            )
        else:
            self.client = QdrantClient(url=QDRANT_URL)

    def create_collection(self, vector_size: int = 1536) -> None:
        """Create a collection in Qdrant if it doesn't exist.

        Args:
            vector_size: Dimensionality of the vectors stored in the
                collection. Only used when the collection is created.

        Raises:
            Exception: Any error from the client is logged and re-raised.
        """
        try:
            # Check if collection exists
            collections = self.client.get_collections().collections
            if not any(col.name == COLLECTION_NAME for col in collections):
                self.client.create_collection(
                    collection_name=COLLECTION_NAME,
                    vectors_config=models.VectorParams(
                        size=vector_size,
                        distance=models.Distance.COSINE
                    ),
                )
                logger.info(f"Created collection: {COLLECTION_NAME}")
            else:
                logger.info(f"Collection {COLLECTION_NAME} already exists")
        except Exception as e:
            logger.error(f"Error creating collection: {str(e)}")
            raise

    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """Add documents with embeddings to the collection.

        Args:
            documents: Dicts read with the keys ``'content'``,
                ``'embedding'`` and ``'metadata'``. From the metadata the
                keys ``'source'``, ``'file_name'`` and ``'file_path'`` are
                copied into the payload (defaulting to ``''``), plus
                ``'chunk_id'`` and ``'total_chunks'`` when present.

        Raises:
            Exception: Any error is logged and re-raised.
        """
        try:
            points = []
            for doc in documents:
                # Generate a unique ID for each document chunk
                point_id = str(uuid.uuid4())

                # Extract content, embedding, and metadata
                content = doc.get('content', '')
                embedding = doc.get('embedding', [])
                # `or {}` also covers an explicit `'metadata': None`.
                metadata = doc.get('metadata') or {}

                # Create payload with all available metadata
                payload = {
                    "content": content,
                    "source": metadata.get('source', ''),
                    "file_name": metadata.get('file_name', ''),
                    "file_path": metadata.get('file_path', ''),
                }

                # Add additional metadata if available
                for optional_key in ('chunk_id', 'total_chunks'):
                    if optional_key in metadata:
                        payload[optional_key] = metadata[optional_key]

                points.append(
                    models.PointStruct(
                        id=point_id,
                        vector=embedding,
                        payload=payload
                    )
                )

            # Upload points to the collection
            self.client.upload_points(
                collection_name=COLLECTION_NAME,
                points=points
            )
            logger.info(f"Added {len(points)} documents to collection {COLLECTION_NAME}")
        except Exception as e:
            logger.error(f"Error adding documents: {str(e)}")
            raise

    def delete_collection(self) -> None:
        """Delete the collection if it exists.

        Raises:
            Exception: Any error from the client is logged and re-raised.
        """
        try:
            self.client.delete_collection(collection_name=COLLECTION_NAME)
            logger.info(f"Deleted collection: {COLLECTION_NAME}")
        except Exception as e:
            logger.error(f"Error deleting collection: {str(e)}")
            raise

    def delete_documents_by_source(self, source: str) -> None:
        """Delete documents that match a specific source.

        Args:
            source: Value compared against the ``source`` payload field.

        Raises:
            Exception: Any error from the client is logged and re-raised.
        """
        try:
            source_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key="source",
                        match=models.MatchValue(value=source)
                    )
                ]
            )

            # Walk every scroll page: a single call returns at most one page,
            # so stopping after the first would leave later matches undeleted.
            point_ids = []
            offset = None
            while True:
                records, offset = self.client.scroll(
                    collection_name=COLLECTION_NAME,
                    scroll_filter=source_filter,
                    limit=self._SCROLL_PAGE_SIZE,
                    offset=offset,
                    with_payload=False,  # only the IDs are needed
                    with_vectors=False,
                )
                point_ids.extend(record.id for record in records)
                if offset is None:
                    break

            if point_ids:
                # Delete the points
                self.client.delete(
                    collection_name=COLLECTION_NAME,
                    points_selector=models.PointIdsList(
                        points=point_ids
                    )
                )
                logger.info(f"Deleted {len(point_ids)} documents from source: {source}")
            else:
                logger.info(f"No documents found from source: {source}")
        except Exception as e:
            logger.error(f"Error deleting documents by source: {str(e)}")
            raise

    def search_similar(self, query_embedding: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
        """Search for similar documents based on embedding.

        Args:
            query_embedding: Query vector passed to the client search.
            top_k: Maximum number of hits requested.

        Returns:
            One dict per hit with the keys ``'content'``, ``'source'``,
            ``'score'`` and ``'id'``. Returns an empty list when the search
            fails; the error is logged and not re-raised.
        """
        try:
            results = self.client.search(
                collection_name=COLLECTION_NAME,
                query_vector=query_embedding,
                limit=top_k
            )

            hits = []
            for hit in results:
                # A point stored without payload yields None here; without
                # this guard one such hit would discard the whole result set.
                payload = hit.payload or {}
                hits.append({
                    'content': payload.get('content', ''),
                    'source': payload.get('source', ''),
                    'score': hit.score,
                    'id': hit.id
                })

            return hits
        except Exception as e:
            logger.error(f"Error searching for similar documents: {str(e)}")
            return []

    def get_all_documents_count(self) -> int:
        """Get the total number of documents in the collection.

        Returns:
            The collection's ``points_count``, or 0 when it is unavailable
            or the lookup fails (the error is logged, not re-raised).
        """
        try:
            info = self.client.get_collection(collection_name=COLLECTION_NAME)
            # points_count may be None; keep the declared int return type.
            return info.points_count or 0
        except Exception as e:
            logger.error(f"Error getting document count: {str(e)}")
            return 0