Spaces:
Running
Running
| # processing/vector_store.py | |
| """ | |
| Main vector store interface that supports all three database options | |
| Unified API for ChromaDB, FAISS, and Pinecone | |
| """ | |
| from typing import List, Dict, Any, Optional, Union | |
| from processing.chroma_manager import ChromaManager | |
| from processing.faiss_manager import FaissManager | |
| from processing.pinecone_manager import PineconeManager | |
| from config.vector_config import get_vector_config, get_available_vector_types | |
class VectorStore:
    """
    Unified vector store interface supporting multiple database backends.

    Supported ``vector_type`` values are ``"chromadb"``, ``"faiss_sqlite"`` and
    ``"pinecone"``. Extra keyword arguments are forwarded unchanged to the
    selected backend manager's constructor. Errors are reported with
    ``print`` rather than raised; a store whose backend could not be created
    has ``self.manager is None`` and its operations return empty/False.
    """

    def __init__(self, vector_type: str = "chromadb", **kwargs):
        self.vector_type = vector_type
        self.manager = None  # backend manager instance, or None if init failed
        self._initialize_manager(vector_type, kwargs)

    def _initialize_manager(self, vector_type: str, config: Dict[str, Any],
                            fallback: bool = True) -> None:
        """Create the backend manager for ``vector_type`` from ``config``.

        On failure the error is printed and ``self.manager`` is left as it
        was. If ``fallback`` is true and the failed backend was not ChromaDB,
        ChromaDB is tried once instead and ``self.vector_type`` is set to
        ``"chromadb"``.
        """
        try:
            if vector_type == "chromadb":
                self.manager = ChromaManager(**config)
            elif vector_type == "faiss_sqlite":
                self.manager = FaissManager(**config)
            elif vector_type == "pinecone":
                self.manager = PineconeManager(**config)
            else:
                raise ValueError(f"Unknown vector type: {vector_type}")
            print(f"✅ Vector store initialized: {vector_type}")
        except Exception as e:
            print(f"❌ Error initializing {vector_type} manager: {e}")
            # Fallback to ChromaDB
            if fallback and vector_type != "chromadb":
                print("🔄 Falling back to ChromaDB...")
                self.vector_type = "chromadb"
                # The failed backend's kwargs (index paths, API keys, ...) are
                # specific to that backend, so ChromaDB is started with its own
                # defaults. NOTE(review): assumes ChromaManager() has usable
                # defaults — confirm against processing/chroma_manager.py.
                self._initialize_manager("chromadb", {}, fallback=False)

    def add_papers(self, papers: List[Dict[str, Any]], **kwargs) -> bool:
        """Add papers to the vector store; False if no backend is available."""
        if not self.manager:
            print("❌ Vector store not initialized")
            return False
        return self.manager.add_papers(papers, **kwargs)

    def search(self, query: str, **kwargs) -> List[Dict[str, Any]]:
        """Search for similar paper chunks; [] if no backend is available."""
        if not self.manager:
            print("❌ Vector store not initialized")
            return []
        return self.manager.search(query, **kwargs)

    def get_stats(self) -> Dict[str, Any]:
        """Return backend statistics plus the active ``vector_type``.

        The backend's stats dict is copied so that tagging it with
        ``vector_type`` never mutates the manager's own data.
        """
        if not self.manager:
            return {"error": "Vector store not initialized"}
        stats = dict(self.manager.get_collection_stats())
        stats["vector_type"] = self.vector_type
        return stats

    def delete_paper(self, paper_id: str) -> bool:
        """Delete a paper from the vector store; False if no backend."""
        if not self.manager:
            print("❌ Vector store not initialized")
            return False
        return self.manager.delete_paper(paper_id)

    def switch_vector_store(self, new_vector_type: str, **kwargs) -> bool:
        """Switch to a different backend, keeping the current one on failure.

        Returns True only if ``new_vector_type`` itself was initialized; no
        silent ChromaDB fallback is attempted here. On any failure the previous
        manager and ``vector_type`` are restored and False is returned.
        """
        old_manager = self.manager
        old_vector_type = self.vector_type
        try:
            print(f"🔄 Switching vector store from {self.vector_type} to {new_vector_type}")
            # Initialize new manager
            self.manager = None
            self.vector_type = new_vector_type
            self._initialize_manager(new_vector_type, kwargs, fallback=False)
            if self.manager is not None:
                print(f"✅ Successfully switched to {new_vector_type}")
                return True
            # Restore old manager if switch failed
            self.manager = old_manager
            self.vector_type = old_vector_type
            print(f"❌ Failed to switch to {new_vector_type}, restored previous manager")
            return False
        except Exception as e:
            print(f"❌ Error switching vector store: {e}")
            self.manager = old_manager
            self.vector_type = old_vector_type
            return False
class MultiVectorStore:
    """
    Manager for multiple vector stores running simultaneously.

    Useful for performance comparison and redundancy. Stores are kept by name
    in ``self.stores``; ``self.active_store`` is the name of the store returned
    by ``get_active_store`` (or None when no store has been added).
    """

    def __init__(self):
        self.stores: Dict[str, "VectorStore"] = {}
        self.active_store: Optional[str] = None

    def add_store(self, name: str, vector_type: str, **kwargs) -> bool:
        """Create a ``VectorStore`` of ``vector_type`` and register it as ``name``.

        ``VectorStore`` reports init errors by printing, not raising, and may
        silently fall back to ChromaDB. A store whose backend is missing, or is
        not the requested one, is therefore rejected (False, nothing
        registered): keeping it under this name would make comparisons lie.
        The first successfully added store becomes the active one.
        """
        try:
            store = VectorStore(vector_type, **kwargs)
        except Exception as e:
            print(f"❌ Error adding vector store {name}: {e}")
            return False
        if store.manager is None:
            print(f"❌ Error adding vector store {name}: {vector_type} backend could not be initialized")
            return False
        if store.vector_type != vector_type:
            print(f"❌ Error adding vector store {name}: requested {vector_type} "
                  f"but backend fell back to {store.vector_type}")
            return False
        self.stores[name] = store
        if self.active_store is None:
            self.active_store = name
        print(f"✅ Added vector store: {name} ({vector_type})")
        return True

    def set_active_store(self, name: str) -> bool:
        """Make ``name`` the active store; False if no such store exists."""
        if name in self.stores:
            self.active_store = name
            print(f"✅ Active vector store set to: {name}")
            return True
        print(f"❌ Vector store {name} not found")
        return False

    def get_active_store(self) -> Optional["VectorStore"]:
        """Return the active store, or None if none is set/registered."""
        if self.active_store is None:
            return None
        return self.stores.get(self.active_store)

    def add_papers_to_all(self, papers: List[Dict[str, Any]],
                          **kwargs) -> Dict[str, bool]:
        """Add ``papers`` to every store; map store name -> success.

        Extra keyword arguments are forwarded to each store's ``add_papers``.
        An exception from one store is printed and recorded as False so the
        remaining stores are still processed.
        """
        results = {}
        for name, store in self.stores.items():
            try:
                results[name] = store.add_papers(papers, **kwargs)
            except Exception as e:
                print(f"❌ Error adding papers to {name}: {e}")
                results[name] = False
        return results

    def search_all(self, query: str, **kwargs) -> Dict[str, List[Dict[str, Any]]]:
        """Run the same search on every store; map store name -> results.

        A store that raises is printed and recorded with an empty list.
        """
        results = {}
        for name, store in self.stores.items():
            try:
                results[name] = store.search(query, **kwargs)
            except Exception as e:
                print(f"❌ Error searching {name}: {e}")
                results[name] = []
        return results

    def compare_performance(self, query: str, n_searches: int = 5) -> Dict[str, Any]:
        """Time ``n_searches`` searches (n_results=10) on every store.

        An empty/None ``query`` is replaced by a default medical-imaging query.
        Returns, per store name, either ``{'average_search_time', 'total_chunks',
        'vector_type', 'search_times'}`` (seconds) or ``{'error': message}``.
        """
        import time

        test_query = query or "machine learning medical imaging"
        if n_searches < 1:
            # An average over zero searches is undefined; report it in the
            # same per-store error shape instead of a bare "division by zero".
            message = f"n_searches must be at least 1, got {n_searches}"
            print(f"❌ {message}")
            return {name: {'error': message} for name in self.stores}

        results = {}
        for name, store in self.stores.items():
            try:
                search_times = []
                for _ in range(n_searches):
                    # perf_counter is monotonic and high-resolution, unlike
                    # time.time(), so it is the right clock for intervals.
                    start_time = time.perf_counter()
                    store.search(test_query, n_results=10)
                    search_times.append(time.perf_counter() - start_time)
                stats = store.get_stats()
                results[name] = {
                    'average_search_time': sum(search_times) / len(search_times),
                    'total_chunks': stats.get('total_chunks', 0),
                    'vector_type': store.vector_type,
                    'search_times': search_times,
                }
            except Exception as e:
                print(f"❌ Error testing {name}: {e}")
                results[name] = {'error': str(e)}
        return results
# Quick test and comparison
def test_all_vector_stores():
    """Smoke-test and compare the locally runnable vector store backends.

    Registers a ChromaDB store and a FAISS+SQLite store under
    ``./data/test_comparison/``, adds one sample paper to every store that
    initialized, runs one search on each and prints a small timing
    comparison. Pinecone is skipped because it requires an API key.
    """
    test_papers = [
        {
            'id': 'comparison_test_001',
            'title': 'Vector Store Comparison Test',
            'abstract': 'This paper tests different vector database implementations for medical research applications.',
            'source': 'test',
            'domain': 'medical_imaging',
            'authors': ['Test Researcher'],
        }
    ]
    print("🧪 Testing All Vector Stores")
    print("=" * 60)
    multi_store = MultiVectorStore()

    # Test ChromaDB
    print("\n🔬 Testing ChromaDB...")
    chroma_success = multi_store.add_store(
        "chromadb_test",
        "chromadb",
        persist_directory="./data/test_comparison/chromadb",
    )

    # Test FAISS
    print("\n🔬 Testing FAISS...")
    faiss_success = multi_store.add_store(
        "faiss_test",
        "faiss_sqlite",
        faiss_index_path="./data/test_comparison/faiss/index.faiss",
        sqlite_db_path="./data/test_comparison/faiss/metadata.db",
        index_type="Flat",
    )

    # Note: Pinecone requires API key, so we'll skip for basic test
    print("\n🔬 Pinecone test skipped (requires API key)")

    successful_stores = [
        name
        for name, ok in (("chromadb_test", chroma_success), ("faiss_test", faiss_success))
        if ok
    ]
    if not successful_stores:
        print("\n❌ No vector store could be initialized; nothing to test")
        return

    # Add papers to successful stores (failed ones were never registered)
    print(f"\n📦 Adding test papers to: {', '.join(successful_stores)}")
    add_results = multi_store.add_papers_to_all(test_papers)
    for store_name, added in add_results.items():
        print(f"  {store_name}: {'papers added' if added else 'FAILED to add papers'}")

    # Test search
    print("\n🔍 Testing search across all stores...")
    search_results = multi_store.search_all("vector database medical research")
    for store_name, results in search_results.items():
        print(f"  {store_name}: {len(results)} results")

    # Compare performance
    print("\n⚡ Performance comparison...")
    performance = multi_store.compare_performance("vector database", n_searches=3)
    for store_name, perf_data in performance.items():
        if 'error' in perf_data:
            print(f"  {store_name}: error - {perf_data['error']}")
        else:
            print(
                f"  {store_name}: {perf_data['average_search_time']:.3f}s avg, "
                f"{perf_data['total_chunks']} chunks"
            )
    print("\n✅ Vector store testing completed!")


if __name__ == "__main__":
    test_all_vector_stores()