import os
import pickle
from typing import List, Dict, Any, Optional

import numpy as np

# Vector store
import faiss
from sentence_transformers import SentenceTransformer

from config import Config

class VectorStore:
    """FAISS-based vector store for document embeddings"""

    def __init__(self, embedding_model: SentenceTransformer, config: Optional[Config] = None):
        self.config = config or Config()
        self.embedding_model = embedding_model

        # Get embedding dimension
        self.dimension = embedding_model.get_sentence_embedding_dimension()

        # Initialize FAISS index (inner product over normalized vectors == cosine similarity)
        self.index = faiss.IndexFlatIP(self.dimension)

        # Storage for chunks and metadata
        self.chunks = []
        self.metadata = []
        self.file_map = {}  # Map file_id to chunk indices

        print(f"✅ Vector store initialized with dimension: {self.dimension}")

    def add_documents(self, chunks: List[str], file_id: str, filename: str):
        """Add documents to vector store"""
        if not chunks:
            print("Warning: No chunks to add")
            return

        print(f"📄 Adding {len(chunks)} chunks from {filename}")

        try:
            # Generate embeddings
            embeddings = self.embedding_model.encode(
                chunks,
                convert_to_numpy=True,
                show_progress_bar=len(chunks) > 10
            )

            # Ensure embeddings are float32
            embeddings = embeddings.astype(np.float32)

            # Normalize embeddings for cosine similarity with inner product
            faiss.normalize_L2(embeddings)

            # Add to FAISS index
            start_idx = len(self.chunks)
            self.index.add(embeddings)

            # Store chunks and metadata
            chunk_indices = []
            for i, chunk in enumerate(chunks):
                chunk_idx = start_idx + i
                chunk_indices.append(chunk_idx)
                self.chunks.append(chunk)
                self.metadata.append({
                    'file_id': file_id,
                    'filename': filename,
                    'chunk_index': i,
                    'global_index': chunk_idx,
                    'text': chunk,
                    'embedding_added': True
                })

            # Update file mapping
            if file_id not in self.file_map:
                self.file_map[file_id] = []
            self.file_map[file_id].extend(chunk_indices)

            print(f"✅ Successfully added {len(chunks)} chunks. Total chunks: {len(self.chunks)}")
        except Exception as e:
            print(f"❌ Error adding documents: {e}")
            raise
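
    # Example call (hypothetical values), assuming `store` is a VectorStore
    # and the file was split into chunks upstream:
    #   store.add_documents(["first chunk ...", "second chunk ..."],
    #                       file_id="doc-1", filename="report.pdf")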

    def search(self, query: str, k: int = 5, file_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Search for similar documents"""
        if len(self.chunks) == 0:
            print("Warning: No documents in vector store")
            return []

        try:
            # Generate query embedding
            query_embedding = self.embedding_model.encode([query], convert_to_numpy=True)
            query_embedding = query_embedding.astype(np.float32)

            # Normalize for cosine similarity
            faiss.normalize_L2(query_embedding)

            # Over-fetch so post-filtering (file_id, removed chunks) can still
            # return up to k hits; never request more than are stored
            search_k = len(self.chunks) if file_id else min(k * 4, len(self.chunks))
            scores, indices = self.index.search(query_embedding, search_k)

            results = []
            for score, idx in zip(scores[0], indices[0]):
                if idx == -1 or idx >= len(self.chunks):  # Invalid index
                    continue
                # Skip chunks whose file was removed
                if self.metadata[idx].get('removed', False):
                    continue
                # Filter by file_id if specified
                if file_id and self.metadata[idx]['file_id'] != file_id:
                    continue
                results.append({
                    'text': self.chunks[idx],
                    'metadata': self.metadata[idx].copy(),
                    'score': float(score),
                    'similarity': float(score)  # Alias for compatibility
                })

            # FAISS already returns hits in descending-score order; re-sort
            # defensively in case the index type ever changes
            results.sort(key=lambda x: x['score'], reverse=True)

            print(f"🔍 Found {len(results)} results for query: '{query[:50]}...'")
            return results[:k]  # Return top k results
        except Exception as e:
            print(f"❌ Search error: {e}")
            return []
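
    # Each hit returned by search() has this shape (values hypothetical):
    #   {'text': '...', 'score': 0.83, 'similarity': 0.83,
    #    'metadata': {'file_id': 'doc-1', 'filename': 'report.pdf',
    #                 'chunk_index': 0, 'global_index': 0, ...}}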

    def get_document_stats(self) -> Dict[str, Any]:
        """Get statistics about stored documents"""
        stats = {
            'total_chunks': len(self.chunks),
            'total_files': len(self.file_map),
            'index_size': self.index.ntotal,
            'dimension': self.dimension
        }

        # File-wise statistics
        file_stats = {}
        for file_id, chunk_indices in self.file_map.items():
            filename = self.metadata[chunk_indices[0]]['filename'] if chunk_indices else 'unknown'
            file_stats[file_id] = {
                'filename': filename,
                'chunk_count': len(chunk_indices),
                'chunk_indices': chunk_indices
            }
        stats['files'] = file_stats

        return stats

    def remove_file(self, file_id: str) -> bool:
        """Remove all chunks for a specific file"""
        if file_id not in self.file_map:
            print(f"Warning: File {file_id} not found in vector store")
            return False

        try:
            # Get chunk indices for this file
            chunk_indices = self.file_map[file_id]

            # Remove from file map
            del self.file_map[file_id]

            # Mark chunks as removed (IndexFlatIP cannot delete individual
            # vectors, so they stay in the index and are filtered at search time)
            for idx in chunk_indices:
                if idx < len(self.metadata):
                    self.metadata[idx]['removed'] = True

            print(f"✅ Marked {len(chunk_indices)} chunks as removed for file {file_id}")
            return True
        except Exception as e:
            print(f"❌ Error removing file {file_id}: {e}")
            return False
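
    # A minimal rebuild sketch (not in the original file): since IndexFlatIP
    # cannot delete vectors, this drops chunks marked 'removed' and re-encodes
    # the survivors from scratch. Only worth calling after many removals,
    # because re-encoding every chunk is the dominant cost.
    def _rebuild_index(self):
        """Rebuild the FAISS index without chunks marked as removed"""
        kept = [(chunk, meta) for chunk, meta in zip(self.chunks, self.metadata)
                if not meta.get('removed', False)]

        # Reset index and bookkeeping
        self.index = faiss.IndexFlatIP(self.dimension)
        self.chunks = [chunk for chunk, _ in kept]
        self.metadata = [meta for _, meta in kept]
        self.file_map = {}
        for new_idx, (_, meta) in enumerate(kept):
            meta['global_index'] = new_idx
            self.file_map.setdefault(meta['file_id'], []).append(new_idx)

        # Re-encode and re-add the surviving chunks
        if self.chunks:
            embeddings = self.embedding_model.encode(
                self.chunks, convert_to_numpy=True
            ).astype(np.float32)
            faiss.normalize_L2(embeddings)
            self.index.add(embeddings)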

    def save(self, path: str):
        """Save vector store to disk"""
        try:
            os.makedirs(path, exist_ok=True)

            # Save FAISS index
            faiss.write_index(self.index, os.path.join(path, "index.faiss"))

            # Save chunks and metadata
            data = {
                'chunks': self.chunks,
                'metadata': self.metadata,
                'file_map': self.file_map,
                'dimension': self.dimension
            }
            with open(os.path.join(path, "data.pkl"), 'wb') as f:
                pickle.dump(data, f)

            print(f"✅ Vector store saved to {path}")
        except Exception as e:
            print(f"❌ Error saving vector store: {e}")
            raise

    def load(self, path: str) -> bool:
        """Load vector store from disk"""
        try:
            index_path = os.path.join(path, "index.faiss")
            data_path = os.path.join(path, "data.pkl")

            if not (os.path.exists(index_path) and os.path.exists(data_path)):
                print(f"Vector store files not found in {path}")
                return False

            # Load FAISS index
            self.index = faiss.read_index(index_path)

            # Load chunks and metadata
            with open(data_path, 'rb') as f:
                data = pickle.load(f)

            self.chunks = data.get('chunks', [])
            self.metadata = data.get('metadata', [])
            self.file_map = data.get('file_map', {})

            # Verify dimension consistency
            saved_dimension = data.get('dimension', self.dimension)
            if saved_dimension != self.dimension:
                print(f"Warning: Dimension mismatch. Expected: {self.dimension}, Got: {saved_dimension}")

            print(f"✅ Vector store loaded from {path}. {len(self.chunks)} chunks, {len(self.file_map)} files")
            return True
        except Exception as e:
            print(f"❌ Error loading vector store: {e}")
            return False
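
    # Note: pickle.load() can execute arbitrary code during deserialization,
    # so only load stores saved by this application from paths you control.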

    def reset(self):
        """Reset vector store (clear all data)"""
        try:
            # Reinitialize FAISS index
            self.index = faiss.IndexFlatIP(self.dimension)

            # Clear data
            self.chunks = []
            self.metadata = []
            self.file_map = {}

            print("✅ Vector store reset successfully")
        except Exception as e:
            print(f"❌ Error resetting vector store: {e}")
            raise

    def get_chunk_by_index(self, index: int) -> Optional[Dict[str, Any]]:
        """Get chunk by global index"""
        if 0 <= index < len(self.chunks):
            return {
                'text': self.chunks[index],
                'metadata': self.metadata[index]
            }
        return None

    def search_by_file(self, file_id: str, query: str = "", k: int = 10) -> List[Dict[str, Any]]:
        """Get all chunks for a specific file, optionally filtered by query"""
        if file_id not in self.file_map:
            return []

        chunk_indices = self.file_map[file_id]
        results = []
        for idx in chunk_indices:
            if idx < len(self.chunks):
                # Skip removed chunks
                if self.metadata[idx].get('removed', False):
                    continue
                results.append({
                    'text': self.chunks[idx],
                    'metadata': self.metadata[idx].copy(),
                    'score': 1.0,  # No scoring for file-based retrieval
                    'global_index': idx
                })

        # If a query is provided, keep only chunks containing it
        if query:
            # Simple substring matching (could be replaced with embedding similarity)
            query_lower = query.lower()
            results = [r for r in results if query_lower in r['text'].lower()]

        return results[:k]

    def optimize_index(self):
        """Optimize FAISS index (placeholder for future enhancements)"""
        # For now, just print stats
        stats = self.get_document_stats()
        print(f"📊 Index stats: {stats['total_chunks']} chunks, {stats['total_files']} files")

        # In the future, we could:
        # - Remove deleted chunks and rebuild the index (see the _rebuild_index sketch above)
        # - Switch to more efficient index types (IVF, HNSW)
        # - Compress embeddings
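

# --- Usage sketch (not part of the original file) ---------------------------
# Assumes the sentence-transformers model name below and a Config that can be
# constructed without arguments; adjust both to your setup.
if __name__ == "__main__":
    model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed model name
    store = VectorStore(model)
    store.add_documents(
        ["FAISS stores dense vectors for fast similarity search.",
         "Cosine similarity compares the directions of two vectors."],
        file_id="demo-1",
        filename="demo.txt",
    )
    for hit in store.search("how are vectors compared?", k=2):
        print(f"{hit['score']:.3f}  {hit['text']}")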