Spaces:
Configuration error
Configuration error
import logging
import os
import sys
import uuid
from typing import List, Dict, Any, Optional

from qdrant_client import QdrantClient
from qdrant_client.http import models

# Add this file's directory to the import path so the sibling ``config``
# module can be imported. ``os`` must be imported for this line to work;
# the absolute path avoids inserting an empty string when ``__file__`` is
# a bare relative file name.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from config import QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME

logger = logging.getLogger(__name__)
class VectorStore:
    """Store and retrieve document embeddings in a Qdrant collection.

    Every method operates on the single collection named by the
    module-level ``COLLECTION_NAME``.
    """

    # Number of point IDs requested per scroll page when deleting by source.
    _SCROLL_PAGE_SIZE = 10000

    def __init__(self) -> None:
        """Create the Qdrant client used by all other methods.

        When ``QDRANT_API_KEY`` is set, the client is created with that key
        and ``prefer_grpc=True``; otherwise it is created from ``QDRANT_URL``
        alone.
        """
        if QDRANT_API_KEY:
            self.client = QdrantClient(
                url=QDRANT_URL,
                api_key=QDRANT_API_KEY,
                prefer_grpc=True,
            )
        else:
            self.client = QdrantClient(url=QDRANT_URL)

    def create_collection(self, vector_size: int = 1536) -> None:
        """Create the collection in Qdrant if it doesn't exist.

        Args:
            vector_size: Dimensionality of the vectors the collection holds.

        Raises:
            Exception: Any error from the client is logged and re-raised.
        """
        try:
            existing = self.client.get_collections().collections
            if any(col.name == COLLECTION_NAME for col in existing):
                logger.info(f"Collection {COLLECTION_NAME} already exists")
                return

            self.client.create_collection(
                collection_name=COLLECTION_NAME,
                vectors_config=models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE,
                ),
            )
            logger.info(f"Created collection: {COLLECTION_NAME}")
        except Exception as e:
            logger.error(f"Error creating collection: {str(e)}")
            raise

    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """Add documents with embeddings to the collection.

        Each document is a dict with the keys ``'content'`` (text),
        ``'embedding'`` (the vector) and ``'metadata'`` (a dict). From the
        metadata, ``source``, ``file_name`` and ``file_path`` are always
        written to the payload (empty string when absent); ``chunk_id`` and
        ``total_chunks`` are copied only when present. Every point gets a
        freshly generated UUID as its ID.

        Args:
            documents: The documents to store.

        Raises:
            ValueError: If a document has no (or an empty) embedding. The
                check runs before anything is uploaded.
            Exception: Any error from the client is logged and re-raised.
        """
        try:
            points = []
            for index, doc in enumerate(documents):
                embedding = doc.get('embedding')
                # An empty vector can never match the collection's vector
                # size, so fail early with a message that names the culprit.
                # ``len`` is used instead of truthiness so array-like
                # embeddings are handled as well.
                if embedding is None or len(embedding) == 0:
                    raise ValueError(
                        f"Document at index {index} has no embedding"
                    )

                # Tolerate an explicit ``'metadata': None``.
                metadata = doc.get('metadata') or {}

                payload = {
                    "content": doc.get('content', ''),
                    "source": metadata.get('source', ''),
                    "file_name": metadata.get('file_name', ''),
                    "file_path": metadata.get('file_path', ''),
                }
                # Optional chunk bookkeeping is copied only when provided.
                for key in ('chunk_id', 'total_chunks'):
                    if key in metadata:
                        payload[key] = metadata[key]

                points.append(
                    models.PointStruct(
                        id=str(uuid.uuid4()),
                        vector=embedding,
                        payload=payload,
                    )
                )

            # Nothing to send for an empty batch; skip the client call.
            if points:
                self.client.upload_points(
                    collection_name=COLLECTION_NAME,
                    points=points,
                )
            logger.info(f"Added {len(points)} documents to collection {COLLECTION_NAME}")
        except Exception as e:
            logger.error(f"Error adding documents: {str(e)}")
            raise

    def delete_collection(self) -> None:
        """Delete the collection.

        Raises:
            Exception: Any error from the client is logged and re-raised.
        """
        try:
            self.client.delete_collection(collection_name=COLLECTION_NAME)
            logger.info(f"Deleted collection: {COLLECTION_NAME}")
        except Exception as e:
            logger.error(f"Error deleting collection: {str(e)}")
            raise

    def delete_documents_by_source(self, source: str) -> None:
        """Delete every point whose payload ``source`` equals ``source``.

        Matching points are read page by page, following the next-page
        offset returned by ``scroll``, and each page is deleted by ID. This
        way sources with more matches than one page are removed completely.

        Args:
            source: Value to match against the ``source`` payload field.

        Raises:
            Exception: Any error from the client is logged and re-raised.
        """
        try:
            source_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key="source",
                        match=models.MatchValue(value=source),
                    )
                ]
            )

            deleted = 0
            offset = None
            while True:
                # Only IDs are needed, so payloads and vectors are not fetched.
                page, offset = self.client.scroll(
                    collection_name=COLLECTION_NAME,
                    scroll_filter=source_filter,
                    limit=self._SCROLL_PAGE_SIZE,
                    offset=offset,
                    with_payload=False,
                    with_vectors=False,
                )
                page_ids = [point.id for point in page]
                if page_ids:
                    self.client.delete(
                        collection_name=COLLECTION_NAME,
                        points_selector=models.PointIdsList(points=page_ids),
                    )
                    deleted += len(page_ids)
                # A ``None`` offset means there are no further pages.
                if offset is None:
                    break

            if deleted:
                logger.info(f"Deleted {deleted} documents from source: {source}")
            else:
                logger.info(f"No documents found from source: {source}")
        except Exception as e:
            logger.error(f"Error deleting documents by source: {str(e)}")
            raise

    def search_similar(self, query_embedding: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
        """Search for the documents most similar to an embedding.

        Args:
            query_embedding: The query vector.
            top_k: Maximum number of hits to return.

        Returns:
            One dict per hit with the keys ``'content'``, ``'source'``,
            ``'score'`` and ``'id'``. On any error the problem is logged and
            an empty list is returned instead of raising.
        """
        try:
            results = self.client.search(
                collection_name=COLLECTION_NAME,
                query_vector=query_embedding,
                limit=top_k,
            )
            hits = []
            for hit in results:
                # A point stored without payload comes back with ``None``.
                payload = hit.payload or {}
                hits.append({
                    'content': payload.get('content', ''),
                    'source': payload.get('source', ''),
                    'score': hit.score,
                    'id': hit.id,
                })
            return hits
        except Exception as e:
            # The error is swallowed on purpose (best effort), so keep the
            # traceback in the log to leave the failure diagnosable.
            logger.error(f"Error searching for similar documents: {str(e)}", exc_info=True)
            return []

    def get_all_documents_count(self) -> int:
        """Get the total number of points in the collection.

        Returns:
            The collection's ``points_count``, or ``0`` when the count is
            unavailable or the lookup fails (the failure is logged).
        """
        try:
            info = self.client.get_collection(collection_name=COLLECTION_NAME)
            # ``points_count`` may be ``None``; honour the ``int`` contract.
            return info.points_count or 0
        except Exception as e:
            logger.error(f"Error getting document count: {str(e)}", exc_info=True)
            return 0