File size: 11,221 Bytes
c71d352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
import os
import pickle
from typing import List, Dict, Any, Optional
import numpy as np

# Vector store
import faiss
from sentence_transformers import SentenceTransformer

from config import Config

class VectorStore:
    """FAISS-based vector store for document embeddings.

    Chunks are embedded with a SentenceTransformer model, L2-normalised and
    added to a flat inner-product FAISS index, so inner-product scores equal
    cosine similarities. Two Python lists (``chunks`` and ``metadata``) are
    kept index-aligned with the FAISS index: position ``i`` in each list
    describes vector ``i`` in the index. ``file_map`` maps each ``file_id``
    to the global indices of its chunks.

    :meth:`remove_file` does not delete vectors from the FAISS index; it only
    marks the chunks' metadata with ``'removed': True``. Retrieval methods
    (:meth:`search`, :meth:`search_by_file`) skip such chunks.
    """

    def __init__(self, embedding_model: SentenceTransformer, config: Optional[Config] = None):
        """Create an empty vector store.

        Args:
            embedding_model: Model used to embed both chunks and queries. Its
                sentence-embedding dimension fixes the index dimension.
            config: Optional configuration; ``Config()`` is constructed when
                omitted (or when a falsy value is passed).
        """
        self.config = config or Config()
        self.embedding_model = embedding_model

        # Get embedding dimension
        self.dimension = embedding_model.get_sentence_embedding_dimension()

        # Inner product over L2-normalised vectors == cosine similarity
        self.index = faiss.IndexFlatIP(self.dimension)

        # Storage for chunks and metadata; both stay index-aligned with self.index
        self.chunks: List[str] = []
        self.metadata: List[Dict[str, Any]] = []
        self.file_map: Dict[str, List[int]] = {}  # file_id -> global chunk indices

        print(f"βœ… Vector store initialized with dimension: {self.dimension}")

    def add_documents(self, chunks: List[str], file_id: str, filename: str):
        """Embed ``chunks`` and append them to the store.

        Args:
            chunks: Text chunks to embed and store. An empty list is a no-op
                (a warning is printed).
            file_id: Identifier the chunks are registered under in ``file_map``.
                Calling this repeatedly with the same ``file_id`` extends the
                existing entry.
            filename: Human-readable name stored in each chunk's metadata.

        Raises:
            Any exception raised while encoding or adding to the FAISS index
            is printed and re-raised.
        """
        if not chunks:
            print("Warning: No chunks to add")
            return

        print(f"πŸ“ Adding {len(chunks)} chunks from {filename}")

        try:
            embeddings = self.embedding_model.encode(
                chunks,
                convert_to_numpy=True,
                show_progress_bar=len(chunks) > 10
            )

            # FAISS requires C-contiguous float32 input; normalize_L2 works in place.
            embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
            faiss.normalize_L2(embeddings)

            # New vectors are appended, so their global indices start at the
            # current chunk count (lists and index are kept aligned).
            start_idx = len(self.chunks)
            self.index.add(embeddings)

            chunk_indices = list(range(start_idx, start_idx + len(chunks)))
            for i, (chunk_idx, chunk) in enumerate(zip(chunk_indices, chunks)):
                self.chunks.append(chunk)
                self.metadata.append({
                    'file_id': file_id,
                    'filename': filename,
                    # Position within this call's `chunks` list (restarts at 0
                    # on every call, even for an existing file_id).
                    'chunk_index': i,
                    'global_index': chunk_idx,
                    'text': chunk,
                    'embedding_added': True
                })

            self.file_map.setdefault(file_id, []).extend(chunk_indices)

            print(f"βœ… Successfully added {len(chunks)} chunks. Total chunks: {len(self.chunks)}")

        except Exception as e:
            print(f"❌ Error adding documents: {e}")
            raise

    def _collect_hits(self, scores: np.ndarray, indices: np.ndarray,
                      file_id: Optional[str]) -> List[Dict[str, Any]]:
        """Convert one row of FAISS hits into result dicts, dropping invalid,
        removed and (when ``file_id`` is truthy) other-file chunks."""
        hits = []
        for score, raw_idx in zip(scores, indices):
            idx = int(raw_idx)
            # FAISS pads missing neighbours with -1; also guard against an
            # index that is longer than the stored chunk list.
            if idx < 0 or idx >= len(self.chunks):
                continue
            meta = self.metadata[idx]
            if meta.get('removed', False):
                continue
            if file_id and meta['file_id'] != file_id:
                continue
            hits.append({
                'text': self.chunks[idx],
                'metadata': meta.copy(),
                'score': float(score),
                'similarity': float(score)  # Alias for compatibility
            })
        return hits

    def search(self, query: str, k: int = 5, file_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return up to ``k`` chunks most similar to ``query``.

        Chunks marked as removed are skipped. When ``file_id`` is truthy, only
        chunks whose metadata ``file_id`` matches are returned. Because
        filtering happens after the FAISS lookup, the lookup is widened
        (doubling) until ``k`` results survive the filters or every stored
        vector has been considered.

        Args:
            query: Query text to embed.
            k: Maximum number of results.
            file_id: Optional file filter.

        Returns:
            Result dicts with ``text``, ``metadata`` (a copy), ``score`` and
            ``similarity`` (same value as ``score``), highest score first.
            Errors during search are printed and an empty list is returned.
        """
        if len(self.chunks) == 0:
            print("Warning: No documents in vector store")
            return []
        if k <= 0:
            return []

        try:
            query_embedding = self.embedding_model.encode([query], convert_to_numpy=True)
            query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)

            # Normalize for cosine similarity
            faiss.normalize_L2(query_embedding)

            # Only positions present in both the index and the chunk list are usable.
            searchable = min(len(self.chunks), self.index.ntotal)
            if searchable <= 0:
                return []

            search_k = min(k, searchable)
            while True:
                scores, indices = self.index.search(query_embedding, search_k)
                results = self._collect_hits(scores[0], indices[0], file_id)
                if len(results) >= k or search_k >= searchable:
                    break
                # Too many hits were filtered out; look further down the ranking.
                search_k = min(search_k * 2, searchable)

            # Sort by score (highest first)
            results.sort(key=lambda x: x['score'], reverse=True)
            results = results[:k]

            print(f"πŸ” Found {len(results)} results for query: '{query[:50]}...'")
            return results

        except Exception as e:
            print(f"❌ Search error: {e}")
            return []

    def get_document_stats(self) -> Dict[str, Any]:
        """Return summary statistics about the stored documents.

        ``total_chunks`` and ``index_size`` include chunks marked as removed;
        ``total_files`` and ``files`` reflect only files still in ``file_map``.
        The per-file ``chunk_indices`` lists are copies, so callers cannot
        mutate internal state through them.
        """
        stats = {
            'total_chunks': len(self.chunks),
            'total_files': len(self.file_map),
            'index_size': self.index.ntotal,
            'dimension': self.dimension
        }

        file_stats = {}
        for file_id, chunk_indices in self.file_map.items():
            filename = self.metadata[chunk_indices[0]]['filename'] if chunk_indices else 'unknown'
            file_stats[file_id] = {
                'filename': filename,
                'chunk_count': len(chunk_indices),
                'chunk_indices': list(chunk_indices)
            }

        stats['files'] = file_stats
        return stats

    def remove_file(self, file_id: str) -> bool:
        """Logically remove all chunks belonging to ``file_id``.

        The FAISS flat index offers no cheap in-place deletion here, so the
        vectors stay in the index. Instead, the file is dropped from
        ``file_map`` and each of its chunks gets ``metadata['removed'] = True``,
        which retrieval methods honour.

        Returns:
            True if the file was found and marked, False otherwise.
        """
        if file_id not in self.file_map:
            print(f"Warning: File {file_id} not found in vector store")
            return False

        try:
            chunk_indices = self.file_map.pop(file_id)

            for idx in chunk_indices:
                if idx < len(self.metadata):
                    self.metadata[idx]['removed'] = True

            print(f"βœ… Marked {len(chunk_indices)} chunks as removed for file {file_id}")
            return True

        except Exception as e:
            print(f"❌ Error removing file {file_id}: {e}")
            return False

    def save(self, path: str):
        """Persist the store to directory ``path``.

        Writes the FAISS index to ``index.faiss`` and the chunks, metadata,
        file map and dimension to ``data.pkl``. The directory is created if
        needed. Errors are printed and re-raised.
        """
        try:
            os.makedirs(path, exist_ok=True)

            faiss.write_index(self.index, os.path.join(path, "index.faiss"))

            data = {
                'chunks': self.chunks,
                'metadata': self.metadata,
                'file_map': self.file_map,
                'dimension': self.dimension
            }

            with open(os.path.join(path, "data.pkl"), 'wb') as f:
                pickle.dump(data, f)

            print(f"βœ… Vector store saved to {path}")

        except Exception as e:
            print(f"❌ Error saving vector store: {e}")
            raise

    def load(self, path: str) -> bool:
        """Load a store previously written by :meth:`save` from ``path``.

        Both files are read into locals before anything on ``self`` is
        replaced, so a failure part-way through leaves the current store
        untouched. An index whose dimension differs from the embedding
        model's is rejected, because it could not be searched or extended.

        Returns:
            True on success; False if files are missing, the dimension does
            not match, or any error occurs (errors are printed).
        """
        try:
            index_path = os.path.join(path, "index.faiss")
            data_path = os.path.join(path, "data.pkl")

            if not (os.path.exists(index_path) and os.path.exists(data_path)):
                print(f"Vector store files not found in {path}")
                return False

            index = faiss.read_index(index_path)

            # SECURITY: pickle.load can execute arbitrary code. Only load
            # stores written by this application from a trusted location.
            with open(data_path, 'rb') as f:
                data = pickle.load(f)

            # Verify dimension consistency before committing any state.
            loaded_dimension = index.d
            if loaded_dimension != self.dimension:
                print(f"Warning: Dimension mismatch. Expected: {self.dimension}, Got: {loaded_dimension}")
                return False

            self.index = index
            self.chunks = data.get('chunks', [])
            self.metadata = data.get('metadata', [])
            self.file_map = data.get('file_map', {})

            print(f"βœ… Vector store loaded from {path}. {len(self.chunks)} chunks, {len(self.file_map)} files")
            return True

        except Exception as e:
            print(f"❌ Error loading vector store: {e}")
            return False

    def reset(self):
        """Clear all data and replace the index with a fresh, empty one."""
        try:
            self.index = faiss.IndexFlatIP(self.dimension)

            self.chunks = []
            self.metadata = []
            self.file_map = {}

            print("βœ… Vector store reset successfully")

        except Exception as e:
            print(f"❌ Error resetting vector store: {e}")
            raise

    def get_chunk_by_index(self, index: int) -> Optional[Dict[str, Any]]:
        """Return ``{'text', 'metadata'}`` for a global chunk index, or None if
        out of range.

        The returned ``metadata`` is the live stored dict, not a copy. Chunks
        marked as removed are still returned.
        """
        if 0 <= index < len(self.chunks):
            return {
                'text': self.chunks[index],
                'metadata': self.metadata[index]
            }
        return None

    def search_by_file(self, file_id: str, query: str = "", k: int = 10) -> List[Dict[str, Any]]:
        """Return up to ``k`` non-removed chunks of ``file_id`` in insertion order.

        If ``query`` is non-empty, only chunks containing it as a
        case-insensitive substring are kept (no embedding similarity is used).
        Every result has ``score`` 1.0 and carries its ``global_index``.
        """
        if file_id not in self.file_map:
            return []

        results = [
            {
                'text': self.chunks[idx],
                'metadata': self.metadata[idx].copy(),
                'score': 1.0,  # No scoring for file-based retrieval
                'global_index': idx
            }
            for idx in self.file_map[file_id]
            if idx < len(self.chunks) and not self.metadata[idx].get('removed', False)
        ]

        if query:
            # Simple text matching (can be enhanced with embedding similarity)
            query_lower = query.lower()
            results = [r for r in results if query_lower in r['text'].lower()]

        return results[:k]

    def optimize_index(self):
        """Print index statistics (placeholder for future optimisation)."""
        stats = self.get_document_stats()
        print(f"πŸ“Š Index stats: {stats['total_chunks']} chunks, {stats['total_files']} files")

        # In the future, we could:
        # - Remove deleted chunks and rebuild index
        # - Switch to more efficient index types (IVF, HNSW)
        # - Compress embeddings