# Source: MedSearchPro/processing/faiss_manager.py
# Author: paulhemb — commit 1367957 ("Initial Backend Deployment")
# processing/faiss_manager.py
"""
FAISS + SQLite vector database implementation
High performance local vector search
"""
import faiss
import numpy as np
import sqlite3
import json
import pickle
from typing import List, Dict, Any, Optional
import os
from embeddings.embedding_models import EmbeddingManager
from embeddings.text_chunking import ResearchPaperChunker
class FaissManager:
    """Vector store pairing a FAISS index (embeddings) with SQLite (metadata).

    Each row of the ``chunks`` table stores in ``embedding_index`` the FAISS id
    of that chunk's vector. Ids are assigned sequentially from ``index.ntotal``
    when vectors are added. Vectors whose chunk row no longer exists (after
    ``delete_paper`` or after re-adding the same paper) stay in the index as
    orphans and are skipped by ``search``.
    """

    def __init__(self,
                 faiss_index_path: str = "./data/vector_db/faiss/index.faiss",
                 sqlite_db_path: str = "./data/vector_db/faiss/metadata.db",
                 embedding_model: str = "all-mpnet-base-v2",
                 chunk_strategy: str = "semantic",
                 index_type: str = "IVFFlat"):
        """Create the embedder and chunker, then open or create the index and database.

        Args:
            faiss_index_path: File the FAISS index is read from and written to.
            sqlite_db_path: SQLite database file holding chunk and paper metadata.
            embedding_model: Model name passed to ``EmbeddingManager``.
            chunk_strategy: Strategy name passed to ``ResearchPaperChunker``.
            index_type: ``"IVFFlat"`` creates an inverted-file index that needs
                training. Any other value creates an exact flat inner-product
                index. This only matters when no index file exists yet.
        """
        self.faiss_index_path = faiss_index_path
        self.sqlite_db_path = sqlite_db_path
        self.embedding_manager = EmbeddingManager(embedding_model)
        self.chunker = ResearchPaperChunker(chunk_strategy)
        self.index_type = index_type

        # Create parent directories. A bare filename has dirname "", and
        # os.makedirs("") raises, so skip empty directory parts.
        for path in (faiss_index_path, sqlite_db_path):
            directory = os.path.dirname(path)
            if directory:
                os.makedirs(directory, exist_ok=True)

        # Initialize FAISS index and SQLite database
        self.index = None
        self.dimension = self.embedding_manager.get_embedding_dimensions()
        self._initialize_faiss_index()
        self._initialize_sqlite_db()
        print(f"βœ… FAISS+SQLite initialized: {faiss_index_path}")

    def _initialize_faiss_index(self):
        """Load the index from ``faiss_index_path`` or create a new, empty one.

        New indexes score by inner product. Both stored and query vectors are
        L2-normalized before use, so that score equals cosine similarity. If
        loading or creation raises, an empty exact ``IndexFlatIP`` is used.
        """
        try:
            if os.path.exists(self.faiss_index_path):
                print("πŸ“‚ Loading existing FAISS index...")
                self.index = faiss.read_index(self.faiss_index_path)
                if self.index.d != self.dimension:
                    # Adding or searching would fail on this mismatch later.
                    # Surface it now.
                    print(f"⚠️ FAISS index dimension {self.index.d} does not match "
                          f"embedding dimension {self.dimension}")
            else:
                print("πŸ†• Creating new FAISS index...")
                if self.index_type == "IVFFlat":
                    quantizer = faiss.IndexFlatIP(self.dimension)
                    # IndexIVFFlat defaults to METRIC_L2, where lower means closer.
                    # Request inner product so scores mean "higher = more similar",
                    # the same as the flat index.
                    self.index = faiss.IndexIVFFlat(
                        quantizer, self.dimension, 100, faiss.METRIC_INNER_PRODUCT)
                    self.index.nprobe = 10  # Number of clusters to search
                else:
                    # Default to flat index (exact search)
                    self.index = faiss.IndexFlatIP(self.dimension)
                print(f"βœ… FAISS index created: {self.index_type}")
        except Exception as e:
            print(f"❌ Error initializing FAISS index: {e}")
            # NOTE(review): if an existing file failed to load, the next
            # add_papers() overwrites it with this fresh index. New vectors then
            # start again at id 0, so existing chunk rows would point at the
            # wrong vectors.
            self.index = faiss.IndexFlatIP(self.dimension)

    def _initialize_sqlite_db(self):
        """Open ``sqlite_db_path`` and create the metadata tables and indexes if missing.

        Raises:
            Exception: Any error raised while setting up the database is logged
                and re-raised.
        """
        try:
            self.conn = sqlite3.connect(self.sqlite_db_path)
            cursor = self.conn.cursor()
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS chunks (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    chunk_id TEXT UNIQUE,
                    paper_id TEXT,
                    paper_title TEXT,
                    text_content TEXT,
                    source TEXT,
                    domain TEXT,
                    publication_date TEXT,
                    chunk_strategy TEXT,
                    start_char INTEGER,
                    end_char INTEGER,
                    embedding_index INTEGER,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS papers (
                    paper_id TEXT PRIMARY KEY,
                    title TEXT,
                    abstract TEXT,
                    source TEXT,
                    domain TEXT,
                    publication_date TEXT,
                    authors TEXT,
                    total_chunks INTEGER,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')
            # search() filters on embedding_index, and add_papers()/delete_paper()
            # filter on paper_id. Without these indexes every lookup is a full
            # table scan.
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_chunks_embedding_index "
                           "ON chunks(embedding_index)")
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_chunks_paper_id "
                           "ON chunks(paper_id)")
            self.conn.commit()
            print("βœ… SQLite database initialized")
        except Exception as e:
            print(f"❌ Error initializing SQLite database: {e}")
            raise

    def _train_index(self, embeddings: np.ndarray) -> None:
        """Train an untrained index, or switch to a flat index if the batch is too small.

        IVF k-means cannot train ``nlist`` clusters on fewer than ``nlist``
        vectors. Without this check, a small first batch would make every add
        fail. Replacing the index is safe because an untrained index holds no
        vectors.
        """
        nlist = getattr(self.index, 'nlist', 0)
        if len(embeddings) < nlist:
            print(f"⚠️ {len(embeddings)} vectors are too few to train {nlist} IVF "
                  f"clusters; using an exact flat index instead")
            self.index = faiss.IndexFlatIP(self.dimension)
        else:
            print("πŸ”§ Training FAISS index...")
            self.index.train(embeddings)

    def add_papers(self, papers: List[Dict[str, Any]], batch_size: int = 100) -> bool:
        """Chunk, embed and store ``papers``. Re-adding a paper replaces its chunks.

        Args:
            papers: Paper dicts. ``id`` and ``title`` are required. ``abstract``,
                ``source``, ``domain``, ``publication_date`` and ``authors`` are
                optional. Chunking is delegated to ``ResearchPaperChunker``.
            batch_size: Currently unused. Kept for interface compatibility.

        Returns:
            True on success. False if no chunks were produced or any step failed;
            in that case the SQLite transaction is rolled back.
        """
        try:
            all_chunks = self.chunker.batch_chunk_papers(papers)
            if not all_chunks:
                print("⚠️ No chunks generated from papers")
                return False

            chunk_texts = [chunk['text'] for chunk in all_chunks]
            embeddings = np.array(self.embedding_manager.encode(chunk_texts)).astype('float32')
            # Normalize so inner product equals cosine similarity.
            faiss.normalize_L2(embeddings)

            if not self.index.is_trained:
                self._train_index(embeddings)

            # When added, the vectors receive ids start_index, start_index + 1, ...
            start_index = self.index.ntotal
            cursor = self.conn.cursor()

            # Remove the existing chunk rows of every paper in this batch, so a
            # re-added paper does not keep stale chunks. Their old vectors
            # become orphans, which search() skips.
            paper_ids = ({paper['id'] for paper in papers}
                         | {chunk['paper_id'] for chunk in all_chunks})
            cursor.executemany("DELETE FROM chunks WHERE paper_id = ?",
                               [(pid,) for pid in paper_ids])

            # Chunk ids are numbered per paper, so they do not depend on which
            # other papers happen to share the batch.
            chunk_counts: Dict[str, int] = {}
            for offset, chunk in enumerate(all_chunks):
                paper_id = chunk['paper_id']
                ordinal = chunk_counts.get(paper_id, 0)
                chunk_counts[paper_id] = ordinal + 1
                cursor.execute('''
                    INSERT OR REPLACE INTO chunks
                    (chunk_id, paper_id, paper_title, text_content, source, domain,
                     publication_date, chunk_strategy, start_char, end_char, embedding_index)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    f"{paper_id}_chunk_{ordinal}",
                    paper_id,
                    chunk['paper_title'],
                    chunk['text'],
                    chunk['source'],
                    chunk['domain'],
                    chunk.get('publication_date', ''),
                    chunk.get('chunk_strategy', 'semantic'),
                    chunk.get('start_char', 0),
                    chunk.get('end_char', 0),
                    start_index + offset,
                ))

            for paper in papers:
                cursor.execute('''
                    INSERT OR REPLACE INTO papers
                    (paper_id, title, abstract, source, domain, publication_date, authors, total_chunks)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    paper['id'],
                    paper['title'],
                    paper.get('abstract', ''),
                    paper.get('source', ''),
                    paper.get('domain', ''),
                    paper.get('publication_date', ''),
                    json.dumps(paper.get('authors', [])),
                    chunk_counts.get(paper['id'], 0),
                ))

            # Touch FAISS only after every row has been staged, and persist the
            # index before committing. If anything fails, the worst outcome is
            # orphaned vectors, never chunk rows pointing at unsaved vectors.
            self.index.add(embeddings)
            faiss.write_index(self.index, self.faiss_index_path)
            self.conn.commit()

            print(f"βœ… Added {len(all_chunks)} chunks from {len(papers)} papers")
            return True
        except Exception as e:
            print(f"❌ Error adding papers to FAISS: {e}")
            self.conn.rollback()
            return False

    def search(self,
               query: str,
               domain: Optional[str] = None,
               n_results: int = 10) -> List[Dict[str, Any]]:
        """Return up to ``n_results`` chunks most similar to ``query``.

        FAISS is asked for ``2 * n_results`` neighbours, capped at the index
        size, so that enough candidates usually survive the optional ``domain``
        filter and the removal of orphaned vectors. Fewer than ``n_results``
        results can still come back.

        Args:
            query: Free-text query, embedded with the same model as the chunks.
            domain: If truthy, only chunks whose ``domain`` equals it are kept.
            n_results: Maximum number of results.

        Returns:
            Dicts with ``text``, ``metadata``, ``distance`` and ``id`` keys,
            sorted by ``distance`` in descending order. ``distance`` is a
            similarity score (cosine similarity of the normalized vectors), so
            higher means more similar. Returns ``[]`` for an empty index,
            ``n_results <= 0``, or any error.
        """
        try:
            if n_results <= 0 or self.index is None or self.index.ntotal == 0:
                return []

            query_embedding = np.array(self.embedding_manager.encode([query])).astype('float32')
            faiss.normalize_L2(query_embedding)

            k = min(n_results * 2, self.index.ntotal)
            scores, ids = self.index.search(query_embedding, k)
            row_scores = scores[0]
            if self.index.metric_type == faiss.METRIC_L2:
                # IVF indexes created before the metric fix use L2, which returns
                # squared distances (lower = closer). For unit vectors,
                # ||a - b||^2 = 2 - 2*cos(a, b), so convert back to cosine similarity.
                row_scores = 1.0 - row_scores / 2.0

            # FAISS pads with id -1 when it finds fewer than k neighbours. Its ids
            # are numpy.int64, which sqlite3 cannot bind, so convert to plain int.
            score_by_index = {int(idx): float(score)
                              for idx, score in zip(ids[0], row_scores) if idx >= 0}
            if not score_by_index:
                return []

            placeholders = ','.join(['?'] * len(score_by_index))
            params: List[Any] = list(score_by_index)
            domain_filter = ""
            if domain:
                domain_filter = "AND c.domain = ?"
                params.append(domain)

            cursor = self.conn.cursor()
            cursor.execute(f'''
                SELECT c.chunk_id, c.paper_id, c.paper_title, c.text_content,
                       c.source, c.domain, c.publication_date, c.chunk_strategy,
                       c.embedding_index
                FROM chunks c
                WHERE c.embedding_index IN ({placeholders}) {domain_filter}
                ORDER BY c.embedding_index
            ''', params)

            formatted_results = []
            for (chunk_id, paper_id, paper_title, text_content, source,
                 chunk_domain, pub_date, chunk_strategy, embedding_index) in cursor.fetchall():
                formatted_results.append({
                    'text': text_content,
                    'metadata': {
                        'paper_id': paper_id,
                        'paper_title': paper_title,
                        'source': source,
                        'domain': chunk_domain,
                        'publication_date': pub_date,
                        'chunk_strategy': chunk_strategy,
                        'embedding_index': embedding_index
                    },
                    'distance': score_by_index[embedding_index],
                    'id': chunk_id
                })

            # Higher score = more similar. The sort is stable, so ties keep
            # ascending embedding_index order.
            formatted_results.sort(key=lambda result: result['distance'], reverse=True)
            return formatted_results[:n_results]
        except Exception as e:
            print(f"❌ FAISS search error: {e}")
            return []

    def get_collection_stats(self) -> Dict[str, Any]:
        """Return chunk/paper counts, per-domain chunk counts and index info ({} on error)."""
        try:
            cursor = self.conn.cursor()
            cursor.execute("SELECT COUNT(*) FROM chunks")
            total_chunks = cursor.fetchone()[0]
            cursor.execute("SELECT COUNT(*) FROM papers")
            total_papers = cursor.fetchone()[0]
            cursor.execute("SELECT domain, COUNT(*) FROM chunks GROUP BY domain")
            domain_distribution = dict(cursor.fetchall())
            return {
                "total_chunks": total_chunks,
                "total_papers": total_papers,
                "domain_distribution": domain_distribution,
                # Includes orphaned vectors, so this can exceed total_chunks.
                "faiss_index_size": getattr(self.index, 'ntotal', 0),
                "embedding_model": self.embedding_manager.model_name,
                "index_type": self.index_type
            }
        except Exception as e:
            print(f"❌ Error getting collection stats: {e}")
            return {}

    def delete_paper(self, paper_id: str) -> bool:
        """Delete a paper and its chunk rows.

        Returns:
            True if chunks were deleted. False if the paper had no chunks or an
            error occurred; on error the transaction is rolled back.
        """
        try:
            cursor = self.conn.cursor()
            cursor.execute("SELECT embedding_index FROM chunks WHERE paper_id = ?", (paper_id,))
            indices_to_remove = [row[0] for row in cursor.fetchall()]
            if not indices_to_remove:
                print(f"⚠️ No chunks found for paper {paper_id}")
                return False

            cursor.execute("DELETE FROM chunks WHERE paper_id = ?", (paper_id,))
            cursor.execute("DELETE FROM papers WHERE paper_id = ?", (paper_id,))
            self.conn.commit()
            # The matching vectors stay in FAISS as orphans. See the helper.
            self._rebuild_index_without_indices(indices_to_remove)
            print(f"βœ… Deleted {len(indices_to_remove)} chunks for paper {paper_id}")
            return True
        except Exception as e:
            print(f"❌ Error deleting paper {paper_id}: {e}")
            self.conn.rollback()
            return False

    def _rebuild_index_without_indices(self, indices_to_remove: List[int]):
        """Report FAISS vectors whose chunk rows were deleted. The vectors themselves are kept.

        Raw embeddings are not stored anywhere, so the index cannot be rebuilt
        here. The vectors stay in place as orphans. search() never returns them,
        because it only reports hits that still have a chunk row, but they take
        up candidate slots until the index is rebuilt from re-embedded chunks.
        """
        if indices_to_remove:
            print("⚠️ FAISS index needs manual rebuild after deletions")

    def close(self) -> None:
        """Close the SQLite connection. It can be called more than once."""
        if hasattr(self, 'conn'):
            self.conn.close()

    def __del__(self):
        """Close the SQLite connection when the manager is garbage-collected."""
        self.close()
# Quick test
def test_faiss_manager():
    """Smoke-test FaissManager end to end: add two sample papers, search, print stats.

    Uses a flat index stored under ./data/test_faiss and reports every step with
    print(). Any exception is caught and reported instead of raised.
    """
    sample_papers = [
        {
            'id': 'test_001',
            'title': 'AI in Medical Imaging',
            'abstract': 'Deep learning transforms medical image analysis with improved accuracy.',
            'source': 'test',
            'domain': 'medical_imaging',
            'authors': ['John Doe', 'Jane Smith'],
        },
        {
            'id': 'test_002',
            'title': 'Genomics and Machine Learning',
            'abstract': 'Machine learning methods advance genomic sequence analysis and prediction.',
            'source': 'test',
            'domain': 'genomics',
            'authors': ['Alan Turing'],
        },
    ]

    print("πŸ§ͺ Testing FAISS Manager")
    print("=" * 50)

    try:
        # A flat index needs no training, so two tiny papers are enough.
        manager = FaissManager(
            faiss_index_path="./data/test_faiss/index.faiss",
            sqlite_db_path="./data/test_faiss/metadata.db",
            index_type="Flat",
        )

        if not manager.add_papers(sample_papers):
            print("❌ Failed to add papers")
            return

        print("βœ… Papers added successfully")

        hits = manager.search("medical image analysis", n_results=5)
        print(f"πŸ” Search results: {len(hits)} chunks found")
        for hit in hits[:2]:
            title = hit['metadata']['paper_title']
            print(f" - {title} (distance: {hit['distance']:.3f})")

        print(f"πŸ“Š Collection stats: {manager.get_collection_stats()}")
    except Exception as e:
        print(f"❌ FAISS test failed: {e}")
# Run the smoke test when this module is executed directly as a script.
if __name__ == "__main__":
    test_faiss_manager()