# NOTE(review): removed stray export artifacts ("Spaces:" / "Running" x2) —
# notebook/markdown run-status residue, not part of the source.
# processing/faiss_manager.py
"""
FAISS + SQLite vector database implementation
High performance local vector search
"""
# Standard library
import json
import os
import pickle
import sqlite3
from typing import Any, Dict, List, Optional

# Third-party
import faiss
import numpy as np

# Project-local
from embeddings.embedding_models import EmbeddingManager
from embeddings.text_chunking import ResearchPaperChunker
class FaissManager:
    """FAISS + SQLite vector database manager.

    FAISS stores the L2-normalized embedding vectors and answers
    nearest-neighbour queries; SQLite stores chunk/paper metadata, joined
    to FAISS rows through the ``embedding_index`` column.
    """

    def __init__(self,
                 faiss_index_path: str = "./data/vector_db/faiss/index.faiss",
                 sqlite_db_path: str = "./data/vector_db/faiss/metadata.db",
                 embedding_model: str = "all-mpnet-base-v2",
                 chunk_strategy: str = "semantic",
                 index_type: str = "IVFFlat"):
        """Create (or reopen) the FAISS index and SQLite metadata store.

        Args:
            faiss_index_path: File the FAISS index is persisted to / loaded from.
            sqlite_db_path: SQLite database file for chunk/paper metadata.
            embedding_model: Model name forwarded to EmbeddingManager.
            chunk_strategy: Strategy forwarded to ResearchPaperChunker.
            index_type: "IVFFlat" for approximate search (requires training);
                any other value falls back to an exact IndexFlatIP.
        """
        self.faiss_index_path = faiss_index_path
        self.sqlite_db_path = sqlite_db_path
        self.embedding_manager = EmbeddingManager(embedding_model)
        self.chunker = ResearchPaperChunker(chunk_strategy)
        self.index_type = index_type

        # Create directories if they don't exist
        os.makedirs(os.path.dirname(faiss_index_path), exist_ok=True)
        os.makedirs(os.path.dirname(sqlite_db_path), exist_ok=True)

        # Initialize FAISS index and SQLite database
        self.index = None
        self.dimension = self.embedding_manager.get_embedding_dimensions()
        self._initialize_faiss_index()
        self._initialize_sqlite_db()
        print(f"β FAISS+SQLite initialized: {faiss_index_path}")

    def _initialize_faiss_index(self):
        """Initialize or load the FAISS index (fallback: exact flat index)."""
        try:
            if os.path.exists(self.faiss_index_path):
                print("π Loading existing FAISS index...")
                self.index = faiss.read_index(self.faiss_index_path)
            else:
                print("π Creating new FAISS index...")
                if self.index_type == "IVFFlat":
                    # IVF partitions the space into 100 clusters and must be
                    # trained on data before vectors can be added.
                    quantizer = faiss.IndexFlatIP(self.dimension)
                    self.index = faiss.IndexIVFFlat(quantizer, self.dimension, 100)
                    self.index.nprobe = 10  # Number of clusters to search
                else:
                    # Exact inner-product search (cosine after L2-normalization).
                    self.index = faiss.IndexFlatIP(self.dimension)
                print(f"β FAISS index created: {self.index_type}")
        except Exception as e:
            print(f"β Error initializing FAISS index: {e}")
            # Fallback to flat index
            self.index = faiss.IndexFlatIP(self.dimension)

    def _initialize_sqlite_db(self):
        """Initialize the SQLite metadata database (tables: chunks, papers).

        Raises:
            Exception: re-raised if the database cannot be created/opened —
                the manager is unusable without its metadata store.
        """
        try:
            self.conn = sqlite3.connect(self.sqlite_db_path)
            cursor = self.conn.cursor()
            # chunks: one row per embedded text chunk; embedding_index is the
            # row's position in the FAISS index.
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS chunks (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    chunk_id TEXT UNIQUE,
                    paper_id TEXT,
                    paper_title TEXT,
                    text_content TEXT,
                    source TEXT,
                    domain TEXT,
                    publication_date TEXT,
                    chunk_strategy TEXT,
                    start_char INTEGER,
                    end_char INTEGER,
                    embedding_index INTEGER,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')
            # papers: one summary row per source paper (authors stored as JSON).
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS papers (
                    paper_id TEXT PRIMARY KEY,
                    title TEXT,
                    abstract TEXT,
                    source TEXT,
                    domain TEXT,
                    publication_date TEXT,
                    authors TEXT,
                    total_chunks INTEGER,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')
            self.conn.commit()
            print("β SQLite database initialized")
        except Exception as e:
            print(f"β Error initializing SQLite database: {e}")
            raise

    def add_papers(self, papers: List[Dict[str, Any]], batch_size: int = 100) -> bool:
        """Chunk, embed, and store papers in FAISS + SQLite.

        Args:
            papers: Paper dicts; each needs at least 'id' and 'title'.
            batch_size: Currently unused; kept for interface compatibility.

        Returns:
            True on success, False if no chunks were produced or any step
            failed (the SQLite transaction is rolled back on failure).
        """
        try:
            all_chunks = self.chunker.batch_chunk_papers(papers)
            if not all_chunks:
                print("β οΈ No chunks generated from papers")
                return False

            # Embed every chunk; L2-normalize so inner product == cosine.
            chunk_texts = [chunk['text'] for chunk in all_chunks]
            embeddings = np.array(self.embedding_manager.encode(chunk_texts)).astype('float32')
            faiss.normalize_L2(embeddings)

            # IVF indexes must be trained once before vectors can be added.
            if isinstance(self.index, faiss.IndexIVFFlat) and not self.index.is_trained:
                print("π§ Training FAISS index...")
                self.index.train(embeddings)

            # New vectors occupy FAISS positions [start_index, start_index + len).
            start_index = self.index.ntotal if hasattr(self.index, 'ntotal') else 0
            self.index.add(embeddings)

            cursor = self.conn.cursor()
            # FIX: number chunks per paper rather than across the whole batch,
            # so a chunk_id is stable regardless of batch composition/order —
            # otherwise INSERT OR REPLACE cannot de-duplicate re-added papers.
            chunk_seq_per_paper: Dict[str, int] = {}
            for offset, chunk in enumerate(all_chunks):
                seq = chunk_seq_per_paper.get(chunk['paper_id'], 0)
                chunk_seq_per_paper[chunk['paper_id']] = seq + 1
                cursor.execute('''
                    INSERT OR REPLACE INTO chunks
                    (chunk_id, paper_id, paper_title, text_content, source, domain,
                     publication_date, chunk_strategy, start_char, end_char, embedding_index)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    f"{chunk['paper_id']}_chunk_{seq}",
                    chunk['paper_id'],
                    chunk['paper_title'],
                    chunk['text'],
                    chunk['source'],
                    chunk['domain'],
                    chunk.get('publication_date', ''),
                    chunk.get('chunk_strategy', 'semantic'),
                    chunk.get('start_char', 0),
                    chunk.get('end_char', 0),
                    start_index + offset,
                ))

            # One summary row per paper.
            for paper in papers:
                paper_chunks = [c for c in all_chunks if c['paper_id'] == paper['id']]
                cursor.execute('''
                    INSERT OR REPLACE INTO papers
                    (paper_id, title, abstract, source, domain, publication_date, authors, total_chunks)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    paper['id'],
                    paper['title'],
                    paper.get('abstract', ''),
                    paper.get('source', ''),
                    paper.get('domain', ''),
                    paper.get('publication_date', ''),
                    json.dumps(paper.get('authors', [])),
                    len(paper_chunks),
                ))

            self.conn.commit()
            # Persist the FAISS index alongside the metadata.
            faiss.write_index(self.index, self.faiss_index_path)
            print(f"β Added {len(all_chunks)} chunks from {len(papers)} papers")
            return True
        except Exception as e:
            print(f"β Error adding papers to FAISS: {e}")
            self.conn.rollback()
            return False

    def search(self,
               query: str,
               domain: Optional[str] = None,
               n_results: int = 10) -> List[Dict[str, Any]]:
        """Search for paper chunks similar to ``query``.

        Args:
            query: Free-text query.
            domain: Optional exact-match filter on the chunk's domain column.
            n_results: Maximum number of results to return.

        Returns:
            Result dicts with 'text', 'metadata', 'distance' (inner-product
            similarity on normalized vectors; HIGHER = more similar) and 'id',
            sorted most-similar first. Empty list on error or no hits.
        """
        try:
            query_embedding = np.array(self.embedding_manager.encode([query])).astype('float32')
            faiss.normalize_L2(query_embedding)

            # Over-fetch so results survive the optional domain filter.
            scores, indices = self.index.search(query_embedding, n_results * 2)

            # FIX: FAISS pads with -1 when it finds fewer neighbours than
            # requested — those must not reach the SQL IN clause. Also cast to
            # Python int: sqlite3 cannot bind numpy integer scalars.
            score_by_index = {
                int(idx): float(score)
                for idx, score in zip(indices[0], scores[0])
                if idx != -1
            }
            if not score_by_index:
                return []

            hit_indices = list(score_by_index)
            placeholders = ','.join('?' for _ in hit_indices)
            domain_filter = "AND domain = ?" if domain else ""
            params = hit_indices + ([domain] if domain else [])

            cursor = self.conn.cursor()
            cursor.execute(f'''
                SELECT c.chunk_id, c.paper_id, c.paper_title, c.text_content,
                       c.source, c.domain, c.publication_date, c.chunk_strategy,
                       c.embedding_index
                FROM chunks c
                WHERE c.embedding_index IN ({placeholders}) {domain_filter}
                ORDER BY c.embedding_index
            ''', params)

            formatted_results = []
            for (chunk_id, paper_id, paper_title, text_content, source,
                 row_domain, pub_date, chunk_strategy, embedding_index) in cursor.fetchall():
                formatted_results.append({
                    'text': text_content,
                    'metadata': {
                        'paper_id': paper_id,
                        'paper_title': paper_title,
                        'source': source,
                        'domain': row_domain,  # renamed: was shadowing the parameter
                        'publication_date': pub_date,
                        'chunk_strategy': chunk_strategy,
                        'embedding_index': embedding_index,
                    },
                    # O(1) dict lookup replaces the per-row np.where scan.
                    'distance': score_by_index[embedding_index],
                    'id': chunk_id,
                })

            # Inner product on L2-normalized vectors: larger score = closer.
            formatted_results.sort(key=lambda r: r['distance'], reverse=True)
            return formatted_results[:n_results]
        except Exception as e:
            print(f"β FAISS search error: {e}")
            return []

    def get_collection_stats(self) -> Dict[str, Any]:
        """Return collection statistics (counts, domain mix, index info).

        Returns an empty dict on error.
        """
        try:
            cursor = self.conn.cursor()
            cursor.execute("SELECT COUNT(*) FROM chunks")
            total_chunks = cursor.fetchone()[0]
            cursor.execute("SELECT COUNT(*) FROM papers")
            total_papers = cursor.fetchone()[0]
            # Per-domain chunk counts, e.g. {"genomics": 12, ...}.
            cursor.execute("SELECT domain, COUNT(*) FROM chunks GROUP BY domain")
            domain_distribution = dict(cursor.fetchall())
            return {
                "total_chunks": total_chunks,
                "total_papers": total_papers,
                "domain_distribution": domain_distribution,
                "faiss_index_size": self.index.ntotal if hasattr(self.index, 'ntotal') else 0,
                "embedding_model": self.embedding_manager.model_name,
                "index_type": self.index_type,
            }
        except Exception as e:
            print(f"β Error getting collection stats: {e}")
            return {}

    def delete_paper(self, paper_id: str) -> bool:
        """Delete all metadata rows for a specific paper.

        NOTE: the paper's vectors are NOT removed from FAISS (see
        _rebuild_index_without_indices); they remain as unreachable orphans
        until the index is rebuilt.

        Returns:
            True if chunks were found and deleted, False otherwise.
        """
        try:
            cursor = self.conn.cursor()
            # Collect the FAISS positions that would need removal.
            cursor.execute("SELECT embedding_index FROM chunks WHERE paper_id = ?", (paper_id,))
            indices_to_remove = [row[0] for row in cursor.fetchall()]

            if not indices_to_remove:
                print(f"β οΈ No chunks found for paper {paper_id}")
                return False

            # Removing individual vectors is unsupported for most FAISS index
            # types; flag the index for a rebuild instead.
            self._rebuild_index_without_indices(indices_to_remove)

            cursor.execute("DELETE FROM chunks WHERE paper_id = ?", (paper_id,))
            cursor.execute("DELETE FROM papers WHERE paper_id = ?", (paper_id,))
            self.conn.commit()
            print(f"β Deleted {len(indices_to_remove)} chunks for paper {paper_id}")
            return True
        except Exception as e:
            print(f"β Error deleting paper {paper_id}: {e}")
            self.conn.rollback()
            return False

    def _rebuild_index_without_indices(self, indices_to_remove: List[int]):
        """Placeholder: announce that a FAISS rebuild is needed.

        Raw embeddings are not persisted outside the FAISS index, so the
        remaining vectors cannot be reconstructed here; re-embedding the
        surviving chunks (a manual rebuild) is required. The original dead
        loop that gathered indices but built nothing has been removed.
        """
        try:
            print("π§ Rebuilding FAISS index...")
            print("β οΈ FAISS index needs manual rebuild after deletions")
        except Exception as e:
            print(f"β Error rebuilding FAISS index: {e}")

    def __del__(self):
        """Best-effort close of the SQLite connection on garbage collection."""
        # FIX: guarded — `conn` may not exist if __init__ failed early, and
        # close() can raise during interpreter shutdown.
        try:
            if hasattr(self, 'conn'):
                self.conn.close()
        except Exception:
            pass
# Quick test
def test_faiss_manager():
    """Smoke-test FaissManager: add two papers, search, print stats."""
    sample_papers = [
        {
            'id': 'test_001',
            'title': 'AI in Medical Imaging',
            'abstract': 'Deep learning transforms medical image analysis with improved accuracy.',
            'source': 'test',
            'domain': 'medical_imaging',
            'authors': ['John Doe', 'Jane Smith'],
        },
        {
            'id': 'test_002',
            'title': 'Genomics and Machine Learning',
            'abstract': 'Machine learning methods advance genomic sequence analysis and prediction.',
            'source': 'test',
            'domain': 'genomics',
            'authors': ['Alan Turing'],
        },
    ]

    print("π§ͺ Testing FAISS Manager")
    print("=" * 50)
    try:
        manager = FaissManager(
            faiss_index_path="./data/test_faiss/index.faiss",
            sqlite_db_path="./data/test_faiss/metadata.db",
            index_type="Flat",  # Use flat for testing (no training needed)
        )

        # Guard clause: bail out early if ingestion fails.
        if not manager.add_papers(sample_papers):
            print("β Failed to add papers")
            return

        print("β Papers added successfully")

        # Exercise similarity search and show the top hits.
        hits = manager.search("medical image analysis", n_results=5)
        print(f"π Search results: {len(hits)} chunks found")
        for hit in hits[:2]:
            print(f" - {hit['metadata']['paper_title']} (distance: {hit['distance']:.3f})")

        print(f"π Collection stats: {manager.get_collection_stats()}")
    except Exception as e:
        print(f"β FAISS test failed: {e}")


if __name__ == "__main__":
    test_faiss_manager()