File size: 6,288 Bytes
bec06d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import os
import sys
import uuid
from typing import List, Dict, Any, Optional

from qdrant_client import QdrantClient
from qdrant_client.http import models

# Add the current directory to the path so we can import config
sys.path.insert(0, os.path.dirname(__file__))
from config import QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME

import logging

logger = logging.getLogger(__name__)

class VectorStore:
    """
    A class to handle vector storage and retrieval using Qdrant.

    Every operation targets the single collection named by
    ``COLLECTION_NAME`` through the client created in ``__init__``.
    """

    # Number of points requested per page when scrolling a filtered result.
    _SCROLL_PAGE_SIZE = 10000

    def __init__(self):
        """Create the Qdrant client for ``QDRANT_URL``.

        When ``QDRANT_API_KEY`` is set, the client is created with that key
        and ``prefer_grpc=True``; otherwise only the URL is passed.
        """
        if QDRANT_API_KEY:
            self.client = QdrantClient(
                url=QDRANT_URL,
                api_key=QDRANT_API_KEY,
                prefer_grpc=True
            )
        else:
            self.client = QdrantClient(url=QDRANT_URL)

    def create_collection(self, vector_size: int = 1536):
        """Create the collection in Qdrant if it doesn't exist.

        Args:
            vector_size: Dimension of the vectors stored in the collection.
                The collection is configured for cosine distance.

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        try:
            # Check if collection exists
            collections = self.client.get_collections().collections
            if not any(col.name == COLLECTION_NAME for col in collections):
                self.client.create_collection(
                    collection_name=COLLECTION_NAME,
                    vectors_config=models.VectorParams(
                        size=vector_size,
                        distance=models.Distance.COSINE
                    ),
                )
                logger.info(f"Created collection: {COLLECTION_NAME}")
            else:
                logger.info(f"Collection {COLLECTION_NAME} already exists")
        except Exception as e:
            logger.error(f"Error creating collection: {str(e)}")
            raise

    def add_documents(self, documents: List[Dict[str, Any]]):
        """Add documents with embeddings to the collection.

        Args:
            documents: Dicts with the keys ``content`` (text), ``embedding``
                (the vector) and optionally ``metadata``. From ``metadata``
                the keys ``source``, ``file_name``, ``file_path``,
                ``chunk_id`` and ``total_chunks`` are copied into the payload.

        Raises:
            ValueError: If a document has no (or an empty) embedding. The
                check runs before anything is uploaded, so a bad document
                cannot leave a partially uploaded batch behind.
            Exception: Any client error is logged and re-raised.
        """
        try:
            points = []

            for index, doc in enumerate(documents):
                embedding = doc.get('embedding')
                # len() rather than truthiness so array-like embeddings
                # (whose truth value may be ambiguous) are handled too.
                if embedding is None or len(embedding) == 0:
                    raise ValueError(
                        f"Document at index {index} has no embedding"
                    )

                content = doc.get('content', '')
                # Tolerate an explicit ``"metadata": None`` entry.
                metadata = doc.get('metadata') or {}

                # Create payload with all available metadata
                payload = {
                    "content": content,
                    "source": metadata.get('source', ''),
                    "file_name": metadata.get('file_name', ''),
                    "file_path": metadata.get('file_path', ''),
                }

                # Chunk bookkeeping is optional; copy it only when present.
                for optional_key in ('chunk_id', 'total_chunks'):
                    if optional_key in metadata:
                        payload[optional_key] = metadata[optional_key]

                points.append(
                    models.PointStruct(
                        # Generate a unique ID for each document chunk
                        id=str(uuid.uuid4()),
                        vector=embedding,
                        payload=payload
                    )
                )

            # Skip the round trip entirely when there is nothing to upload.
            if points:
                self.client.upload_points(
                    collection_name=COLLECTION_NAME,
                    points=points
                )

            logger.info(f"Added {len(points)} documents to collection {COLLECTION_NAME}")
        except Exception as e:
            logger.error(f"Error adding documents: {str(e)}")
            raise

    def delete_collection(self):
        """Delete the collection.

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        try:
            self.client.delete_collection(collection_name=COLLECTION_NAME)
            logger.info(f"Deleted collection: {COLLECTION_NAME}")
        except Exception as e:
            logger.error(f"Error deleting collection: {str(e)}")
            raise

    def delete_documents_by_source(self, source: str):
        """Delete all documents whose payload ``source`` equals *source*.

        Matching point IDs are collected page by page, following the
        next-page offset returned by ``scroll``, so sources with more points
        than one page holds are deleted completely.

        Args:
            source: Value to match against the ``source`` payload field.

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        try:
            source_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key="source",
                        match=models.MatchValue(value=source)
                    )
                ]
            )

            point_ids = []
            offset = None
            while True:
                # Only the IDs are needed, so skip payloads and vectors.
                records, offset = self.client.scroll(
                    collection_name=COLLECTION_NAME,
                    scroll_filter=source_filter,
                    limit=self._SCROLL_PAGE_SIZE,
                    offset=offset,
                    with_payload=False,
                    with_vectors=False
                )
                point_ids.extend(record.id for record in records)
                # A missing next-page offset marks the last page.
                if offset is None:
                    break

            if point_ids:
                # Delete the points
                self.client.delete(
                    collection_name=COLLECTION_NAME,
                    points_selector=models.PointIdsList(
                        points=point_ids
                    )
                )
                logger.info(f"Deleted {len(point_ids)} documents from source: {source}")
            else:
                logger.info(f"No documents found from source: {source}")
        except Exception as e:
            logger.error(f"Error deleting documents by source: {str(e)}")
            raise

    def search_similar(self, query_embedding: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
        """Search for similar documents based on embedding.

        Args:
            query_embedding: Query vector to compare against stored vectors.
            top_k: Maximum number of hits to return.

        Returns:
            One dict per hit with the keys ``content``, ``source``, ``score``
            and ``id``. On any error the error is logged and an empty list is
            returned instead of raising.
        """
        try:
            # NOTE(review): newer qdrant-client releases appear to deprecate
            # ``search`` in favour of ``query_points`` -- confirm against the
            # pinned client version before upgrading.
            results = self.client.search(
                collection_name=COLLECTION_NAME,
                query_vector=query_embedding,
                limit=top_k
            )

            hits = []
            for hit in results:
                # A point stored without payload must not discard every hit.
                payload = hit.payload or {}
                hits.append({
                    'content': payload.get('content', ''),
                    'source': payload.get('source', ''),
                    'score': hit.score,
                    'id': hit.id
                })

            return hits
        except Exception as e:
            logger.error(f"Error searching for similar documents: {str(e)}")
            return []

    def get_all_documents_count(self) -> int:
        """Get the total number of documents in the collection.

        Returns:
            The collection's point count, or 0 when the count is unavailable
            or the lookup fails (the error is logged).
        """
        try:
            info = self.client.get_collection(collection_name=COLLECTION_NAME)
            # points_count may be None; keep the declared ``int`` contract.
            return info.points_count or 0
        except Exception as e:
            logger.error(f"Error getting document count: {str(e)}")
            return 0