# processing/faiss_manager.py
"""
FAISS + SQLite vector database implementation
High performance local vector search
"""

import faiss
import numpy as np
import sqlite3
import json
import pickle
from collections import Counter
from typing import List, Dict, Any, Optional
import os

from embeddings.embedding_models import EmbeddingManager
from embeddings.text_chunking import ResearchPaperChunker


class FaissManager:
    """FAISS + SQLite vector database manager.

    Chunk embeddings are stored in a FAISS index and chunk/paper metadata in
    SQLite. The two stores are linked by ``chunks.embedding_index``, which
    holds the FAISS position assigned to the chunk's vector when it was added.
    Vectors are L2-normalized before indexing, so inner-product scores are
    cosine similarities.
    """

    def __init__(self,
                 faiss_index_path: str = "./data/vector_db/faiss/index.faiss",
                 sqlite_db_path: str = "./data/vector_db/faiss/metadata.db",
                 embedding_model: str = "all-mpnet-base-v2",
                 chunk_strategy: str = "semantic",
                 index_type: str = "IVFFlat",
                 nlist: int = 100,
                 nprobe: int = 10):
        """Create the manager, loading or creating the FAISS index and SQLite DB.

        Args:
            faiss_index_path: File the FAISS index is read from / written to.
            sqlite_db_path: SQLite file holding chunk and paper metadata.
            embedding_model: Model name passed to ``EmbeddingManager``.
            chunk_strategy: Strategy name passed to ``ResearchPaperChunker``.
            index_type: ``"IVFFlat"`` creates an approximate IVF index, which
                must be trained on at least ``nlist`` vectors in the first
                ``add_papers`` call. Any other value creates an exact flat
                inner-product index.
            nlist: Number of IVF clusters. Used only when a new IVFFlat index
                is created.
            nprobe: Number of IVF clusters scanned per query. Applied to any
                index that has an ``nprobe`` setting.
        """
        self.faiss_index_path = faiss_index_path
        self.sqlite_db_path = sqlite_db_path
        self.embedding_manager = EmbeddingManager(embedding_model)
        self.chunker = ResearchPaperChunker(chunk_strategy)
        self.index_type = index_type
        self.nlist = nlist
        self.nprobe = nprobe

        # Create parent directories. dirname() is "" for bare file names and
        # os.makedirs("") would raise, so skip those.
        for path in (faiss_index_path, sqlite_db_path):
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)

        # Initialize FAISS index and SQLite database
        self.index = None
        self.dimension = self.embedding_manager.get_embedding_dimensions()
        self._initialize_faiss_index()
        self._initialize_sqlite_db()

        print(f"✅ FAISS+SQLite initialized: {faiss_index_path}")

    def _initialize_faiss_index(self):
        """Load the FAISS index from disk, or create a new empty one.

        On any error, fall back to an in-memory exact flat inner-product index
        and record ``index_type`` as ``"Flat"`` so stats reflect reality.
        """
        try:
            if os.path.exists(self.faiss_index_path):
                print("📂 Loading existing FAISS index...")
                self.index = faiss.read_index(self.faiss_index_path)
            else:
                print("🆕 Creating new FAISS index...")
                if self.index_type == "IVFFlat":
                    # IVF index for faster approximate search (requires training).
                    quantizer = faiss.IndexFlatIP(self.dimension)
                    # IndexIVFFlat defaults to METRIC_L2. Request inner product
                    # explicitly so scores are cosine similarities like the flat
                    # index, which is what search() sorts on.
                    self.index = faiss.IndexIVFFlat(
                        quantizer, self.dimension, self.nlist,
                        faiss.METRIC_INNER_PRODUCT)
                else:
                    # Default to flat index (exact search)
                    self.index = faiss.IndexFlatIP(self.dimension)
                print(f"✅ FAISS index created: {self.index_type}")

            if hasattr(self.index, 'nprobe'):
                self.index.nprobe = self.nprobe  # Number of clusters to search
        except Exception as e:
            print(f"❌ Error initializing FAISS index: {e}")
            self.index = faiss.IndexFlatIP(self.dimension)
            self.index_type = "Flat"

    def _initialize_sqlite_db(self):
        """Open the SQLite database and create tables/indexes if they are missing.

        Raises:
            Exception: any error from connecting or creating the schema is
                printed and re-raised.
        """
        try:
            self.conn = sqlite3.connect(self.sqlite_db_path)
            cursor = self.conn.cursor()

            cursor.execute('''
                CREATE TABLE IF NOT EXISTS chunks (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    chunk_id TEXT UNIQUE,
                    paper_id TEXT,
                    paper_title TEXT,
                    text_content TEXT,
                    source TEXT,
                    domain TEXT,
                    publication_date TEXT,
                    chunk_strategy TEXT,
                    start_char INTEGER,
                    end_char INTEGER,
                    embedding_index INTEGER,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            cursor.execute('''
                CREATE TABLE IF NOT EXISTS papers (
                    paper_id TEXT PRIMARY KEY,
                    title TEXT,
                    abstract TEXT,
                    source TEXT,
                    domain TEXT,
                    publication_date TEXT,
                    authors TEXT,
                    total_chunks INTEGER,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # search() looks chunks up by embedding_index, and delete_paper() /
            # add_papers() by paper_id. Without indexes both are full table scans.
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_chunks_embedding_index "
                "ON chunks(embedding_index)")
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_chunks_paper_id "
                "ON chunks(paper_id)")

            self.conn.commit()
            print("✅ SQLite database initialized")

        except Exception as e:
            print(f"❌ Error initializing SQLite database: {e}")
            raise

    def _encode_normalized(self, texts: List[str]) -> np.ndarray:
        """Encode ``texts`` into a fresh float32 matrix whose rows have unit L2 norm."""
        # np.array copies, so the in-place normalization below never mutates
        # whatever object the embedding manager returned.
        vectors = np.array(self.embedding_manager.encode(texts), dtype=np.float32)
        if vectors.ndim == 1:
            vectors = vectors.reshape(1, -1)
        # faiss.normalize_L2 works in place and requires C-contiguous float32.
        vectors = np.ascontiguousarray(vectors)
        faiss.normalize_L2(vectors)
        return vectors

    def add_papers(self, papers: List[Dict[str, Any]], batch_size: int = 100) -> bool:
        """Chunk, embed and store papers in FAISS + SQLite.

        A paper that is already stored is replaced. Its previous chunk rows are
        deleted. Its old FAISS vectors stay in the index but no longer map to
        any row, so ``search`` skips them.

        Args:
            papers: Paper dicts with at least ``'id'`` and ``'title'``.
                ``'abstract'``, ``'source'``, ``'domain'``,
                ``'publication_date'`` and ``'authors'`` are optional.
            batch_size: Currently unused; kept for API compatibility.

        Returns:
            True on success. False if no chunks were produced or an error
            occurred; in the error case pending SQLite changes are rolled back.
        """
        try:
            all_chunks = self.chunker.batch_chunk_papers(papers)
            if not all_chunks:
                print("⚠️ No chunks generated from papers")
                return False

            embeddings = self._encode_normalized([chunk['text'] for chunk in all_chunks])
            if embeddings.shape[0] != len(all_chunks):
                raise ValueError(
                    f"Embedding model returned {embeddings.shape[0]} vectors "
                    f"for {len(all_chunks)} chunks")

            # Any index type that needs training (e.g. IVF) is trained on the
            # first batch it sees, including indexes loaded from disk.
            if not self.index.is_trained:
                nlist = getattr(self.index, 'nlist', 0)
                if embeddings.shape[0] < nlist:
                    raise ValueError(
                        f"IVF index needs at least {nlist} vectors to train, "
                        f"got {embeddings.shape[0]}; add more papers at once "
                        f"or use index_type='Flat'")
                print("🔧 Training FAISS index...")
                self.index.train(embeddings)

            # New vectors get consecutive positions starting at the current size.
            start_index = self.index.ntotal
            cursor = self.conn.cursor()

            # Remove stale rows of papers being (re-)added, so the chunks table
            # only describes their latest version.
            replaced_ids = dict.fromkeys(
                [paper['id'] for paper in papers]
                + [chunk['paper_id'] for chunk in all_chunks])
            cursor.executemany("DELETE FROM chunks WHERE paper_id = ?",
                               [(paper_id,) for paper_id in replaced_ids])

            chunks_per_paper = Counter()
            chunk_rows = []
            for offset, chunk in enumerate(all_chunks):
                paper_id = chunk['paper_id']
                # Number chunks per paper, not by batch position, so a paper's
                # chunk ids do not depend on which other papers share the batch.
                chunk_number = chunks_per_paper[paper_id]
                chunks_per_paper[paper_id] += 1
                chunk_rows.append((
                    f"{paper_id}_chunk_{chunk_number}",
                    paper_id,
                    chunk['paper_title'],
                    chunk['text'],
                    chunk['source'],
                    chunk['domain'],
                    chunk.get('publication_date', ''),
                    chunk.get('chunk_strategy', 'semantic'),
                    chunk.get('start_char', 0),
                    chunk.get('end_char', 0),
                    start_index + offset,
                ))

            cursor.executemany('''
                INSERT OR REPLACE INTO chunks
                (chunk_id, paper_id, paper_title, text_content, source, domain,
                 publication_date, chunk_strategy, start_char, end_char, embedding_index)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ''', chunk_rows)

            cursor.executemany('''
                INSERT OR REPLACE INTO papers
                (paper_id, title, abstract, source, domain, publication_date, authors, total_chunks)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ''', [
                (
                    paper['id'],
                    paper['title'],
                    paper.get('abstract', ''),
                    paper.get('source', ''),
                    paper.get('domain', ''),
                    paper.get('publication_date', ''),
                    json.dumps(paper.get('authors', [])),
                    chunks_per_paper[paper['id']],  # Counter yields 0 for chunkless papers
                )
                for paper in papers
            ])

            # Add vectors only after every metadata write has succeeded, so a
            # failed SQL statement no longer leaves unmapped vectors behind.
            self.index.add(embeddings)
            self.conn.commit()

            faiss.write_index(self.index, self.faiss_index_path)

            print(f"✅ Added {len(all_chunks)} chunks from {len(papers)} papers")
            return True

        except Exception as e:
            print(f"❌ Error adding papers to FAISS: {e}")
            self.conn.rollback()
            return False

    def search(self, query: str, domain: Optional[str] = None,
               n_results: int = 10) -> List[Dict[str, Any]]:
        """Return the stored chunks most similar to ``query``.

        Args:
            query: Free-text query.
            domain: If truthy, only chunks whose ``domain`` equals it are returned.
            n_results: Maximum number of results.

        Returns:
            Up to ``n_results`` dicts with keys ``'text'``, ``'metadata'``,
            ``'distance'`` and ``'id'``, most similar first. Despite its name,
            ``'distance'`` is a cosine similarity: higher means more similar.
            Returns ``[]`` when the index is empty or on error.
        """
        try:
            if n_results <= 0 or self.index.ntotal == 0:
                return []

            query_embedding = self._encode_normalized([query])

            # Over-fetch so that domain filtering and orphaned vectors (deleted
            # or replaced chunks) still leave enough candidates.
            k = min(n_results * 2, self.index.ntotal)
            raw_scores, raw_ids = self.index.search(query_embedding, k)

            # IVFFlat indexes created before the metric was set explicitly use
            # squared L2. For unit vectors ||a-b||^2 = 2 - 2*cos(a, b), so
            # convert those scores to cosine similarity.
            is_l2 = self.index.metric_type == faiss.METRIC_L2

            # FAISS pads missing hits with -1. .tolist() turns numpy int64 into
            # Python int, which sqlite3 can bind (numpy.int64 is rejected).
            score_by_position: Dict[int, float] = {}
            for position, score in zip(raw_ids[0].tolist(), raw_scores[0].tolist()):
                if position == -1 or position in score_by_position:
                    continue
                score_by_position[position] = 1.0 - score / 2.0 if is_l2 else score
            if not score_by_position:
                return []

            placeholders = ','.join('?' * len(score_by_position))
            domain_filter = "AND c.domain = ?" if domain else ""
            params = list(score_by_position) + ([domain] if domain else [])

            cursor = self.conn.cursor()
            cursor.execute(f'''
                SELECT c.chunk_id, c.paper_id, c.paper_title, c.text_content, c.source,
                       c.domain, c.publication_date, c.chunk_strategy, c.embedding_index
                FROM chunks c
                WHERE c.embedding_index IN ({placeholders}) {domain_filter}
                ORDER BY c.embedding_index
            ''', params)

            formatted_results = []
            for (chunk_id, paper_id, paper_title, text_content, source,
                 chunk_domain, pub_date, chunk_strategy, embedding_index) in cursor.fetchall():
                score = score_by_position.get(embedding_index)
                if score is None:
                    continue
                formatted_results.append({
                    'text': text_content,
                    'metadata': {
                        'paper_id': paper_id,
                        'paper_title': paper_title,
                        'source': source,
                        'domain': chunk_domain,
                        'publication_date': pub_date,
                        'chunk_strategy': chunk_strategy,
                        'embedding_index': embedding_index
                    },
                    'distance': score,
                    'id': chunk_id
                })

            # Higher cosine similarity = more similar.
            formatted_results.sort(key=lambda result: result['distance'], reverse=True)
            return formatted_results[:n_results]

        except Exception as e:
            print(f"❌ FAISS search error: {e}")
            return []

    def get_collection_stats(self) -> Dict[str, Any]:
        """Return chunk/paper counts, per-domain chunk counts and index info.

        ``faiss_index_size`` counts every vector in the index, including
        orphans left by deleted or replaced papers. Returns ``{}`` on error.
        """
        try:
            cursor = self.conn.cursor()

            cursor.execute("SELECT COUNT(*) FROM chunks")
            total_chunks = cursor.fetchone()[0]

            cursor.execute("SELECT COUNT(*) FROM papers")
            total_papers = cursor.fetchone()[0]

            cursor.execute("SELECT domain, COUNT(*) FROM chunks GROUP BY domain")
            domain_distribution = dict(cursor.fetchall())

            return {
                "total_chunks": total_chunks,
                "total_papers": total_papers,
                "domain_distribution": domain_distribution,
                "faiss_index_size": self.index.ntotal if hasattr(self.index, 'ntotal') else 0,
                "embedding_model": self.embedding_manager.model_name,
                "index_type": self.index_type
            }

        except Exception as e:
            print(f"❌ Error getting collection stats: {e}")
            return {}

    def delete_paper(self, paper_id: str) -> bool:
        """Delete a paper's chunk rows and its paper row from SQLite.

        The paper's FAISS vectors are NOT removed (see
        ``_rebuild_index_without_indices``). Because their rows are gone,
        ``search`` no longer returns them.

        Returns:
            True if chunks were found and deleted. False if the paper has no
            chunks or an error occurred; in the error case changes are rolled back.
        """
        try:
            cursor = self.conn.cursor()

            cursor.execute("SELECT embedding_index FROM chunks WHERE paper_id = ?", (paper_id,))
            indices_to_remove = [row[0] for row in cursor.fetchall()]

            if not indices_to_remove:
                print(f"⚠️ No chunks found for paper {paper_id}")
                return False

            self._rebuild_index_without_indices(indices_to_remove)

            cursor.execute("DELETE FROM chunks WHERE paper_id = ?", (paper_id,))
            cursor.execute("DELETE FROM papers WHERE paper_id = ?", (paper_id,))
            self.conn.commit()

            print(f"✅ Deleted {len(indices_to_remove)} chunks for paper {paper_id}")
            return True

        except Exception as e:
            print(f"❌ Error deleting paper {paper_id}: {e}")
            self.conn.rollback()
            return False

    def _rebuild_index_without_indices(self, indices_to_remove: List[int]):
        """Placeholder for compacting the FAISS index after deletions.

        Not implemented: no vectors are removed. The caller deletes the
        matching SQLite rows, so these vectors become orphans that ``search``
        skips. They still occupy index slots and count toward ``ntotal``.
        A real rebuild would have to reconstruct the kept vectors into a new
        index and rewrite every ``chunks.embedding_index`` to its new position.
        """
        print("🔧 Rebuilding FAISS index...")
        print("⚠️ FAISS index needs manual rebuild after deletions")

    def close(self):
        """Close the SQLite connection. Safe to call more than once."""
        conn = getattr(self, 'conn', None)
        if conn is not None:
            conn.close()

    def __del__(self):
        """Close the SQLite connection when the manager is garbage-collected."""
        self.close()
# Quick test
def test_faiss_manager():
    """Smoke-test FaissManager: add two papers, run a search, print the stats.

    The index and database are created in a temporary directory, so every run
    starts from an empty index. A fixed path would be reloaded by
    FaissManager and would grow with duplicate vectors on each rerun.
    """
    import tempfile

    test_papers = [
        {
            'id': 'test_001',
            'title': 'AI in Medical Imaging',
            'abstract': 'Deep learning transforms medical image analysis with improved accuracy.',
            'source': 'test',
            'domain': 'medical_imaging',
            'authors': ['John Doe', 'Jane Smith']
        },
        {
            'id': 'test_002',
            'title': 'Genomics and Machine Learning',
            'abstract': 'Machine learning methods advance genomic sequence analysis and prediction.',
            'source': 'test',
            'domain': 'genomics',
            'authors': ['Alan Turing']
        }
    ]

    print("🧪 Testing FAISS Manager")
    print("=" * 50)

    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            manager = FaissManager(
                faiss_index_path=os.path.join(tmp_dir, "index.faiss"),
                sqlite_db_path=os.path.join(tmp_dir, "metadata.db"),
                index_type="Flat"  # Use flat for testing (no training needed)
            )
            try:
                success = manager.add_papers(test_papers)

                if success:
                    print("✅ Papers added successfully")

                    results = manager.search("medical image analysis", n_results=5)
                    print(f"🔍 Search results: {len(results)} chunks found")

                    for result in results[:2]:
                        print(f"  - {result['metadata']['paper_title']} "
                              f"(distance: {result['distance']:.3f})")

                    stats = manager.get_collection_stats()
                    print(f"📊 Collection stats: {stats}")
                else:
                    print("❌ Failed to add papers")
            finally:
                # Release the SQLite file before the temporary directory is
                # removed; an open file cannot be deleted on Windows.
                manager.conn.close()

    except Exception as e:
        print(f"❌ FAISS test failed: {e}")


if __name__ == "__main__":
    test_faiss_manager()