Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Vector Store Manager for Maternal Health RAG Chatbot | |
| Uses FAISS with the optimal all-MiniLM-L6-v2 embedding model | |
| """ | |
| import json | |
| import numpy as np | |
| import faiss | |
| from pathlib import Path | |
| from typing import List, Dict, Any, Tuple, Optional | |
| import logging | |
| from sentence_transformers import SentenceTransformer | |
| import pickle | |
| import time | |
| from dataclasses import dataclass | |
# Configure logging: send INFO-and-above records through the root logger.
# (logging.basicConfig does nothing if the root logger already has handlers.)
logging.basicConfig(level=logging.INFO)
# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)
@dataclass
class SearchResult:
    """Container for a single vector-store search hit.

    ``@dataclass`` generates the keyword ``__init__`` (plus ``__repr__`` and
    ``__eq__``) that ``MaternalHealthVectorStore.search`` relies on when it
    builds results with ``SearchResult(content=..., score=..., ...)``.
    """
    content: str                  # text of the matched chunk
    score: float                  # similarity score returned by the index
    metadata: Dict[str, Any]      # full metadata dict stored for the chunk
    chunk_index: int              # row position of the chunk in the store
    source_document: str          # metadata['source'] ('' when absent)
    chunk_type: str               # metadata['chunk_type'] ('text' when absent)
    clinical_importance: float    # metadata['clinical_importance'] (0.5 when absent)
class MaternalHealthVectorStore:
    """Vector store for maternal health guidelines with clinical context filtering.

    Chunk texts are embedded with a SentenceTransformer model, L2-normalised and
    added to a FAISS ``IndexFlatIP`` index, so inner-product scores are cosine
    similarities. ``self.documents`` and ``self.metadata`` are parallel lists
    addressed by FAISS row id; index, texts, metadata and a small config are
    persisted under ``vector_store_dir``.
    """

    def __init__(self,
                 vector_store_dir: str = "vector_store",
                 embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
                 chunks_dir: str = "comprehensive_chunks"):
        """Set up paths and empty state; no model or index is loaded here.

        Args:
            vector_store_dir: Directory for the persisted index files (created,
                including missing parent directories, if it does not exist).
            embedding_model: SentenceTransformer model name used for embeddings.
            chunks_dir: Directory containing
                ``langchain_documents_comprehensive.json``.
        """
        self.vector_store_dir = Path(vector_store_dir)
        self.vector_store_dir.mkdir(parents=True, exist_ok=True)
        self.chunks_dir = Path(chunks_dir)
        self.embedding_model_name = embedding_model

        # Components populated by initialize_embedding_model() /
        # create_vector_index() / load_existing_index()
        self.embedding_model = None
        self.embedding_dimension: Optional[int] = None
        self.index = None
        self.documents: List[str] = []
        self.metadata: List[Dict[str, Any]] = []

        # Vector store files
        self.index_file = self.vector_store_dir / "faiss_index.bin"
        self.documents_file = self.vector_store_dir / "documents.json"
        self.metadata_file = self.vector_store_dir / "metadata.json"
        self.config_file = self.vector_store_dir / "config.json"

        # Search parameters
        self.default_k = 5
        self.similarity_threshold = 0.3

    def initialize_embedding_model(self):
        """Load the SentenceTransformer named by ``self.embedding_model_name``.

        Also sets ``self.embedding_dimension`` from the width of a probe
        embedding. Any loading error is logged and re-raised.
        """
        logger.info(f"Initializing embedding model: {self.embedding_model_name}")
        try:
            self.embedding_model = SentenceTransformer(self.embedding_model_name)
            logger.info("β Embedding model loaded successfully")

            # Encode a probe string to learn the output dimension
            test_embedding = self.embedding_model.encode(["test"])
            self.embedding_dimension = test_embedding.shape[1]
            logger.info(f"π Embedding dimension: {self.embedding_dimension}")
        except Exception as e:
            logger.error(f"β Failed to load embedding model: {e}")
            raise

    def load_medical_documents(self) -> List[Dict[str, Any]]:
        """Load the processed chunk list from the chunks directory.

        Returns:
            The parsed JSON list; each item is read elsewhere via its
            ``'page_content'`` and ``'metadata'`` keys.

        Raises:
            FileNotFoundError: If the chunk file does not exist.
        """
        logger.info("Loading medical documents for vector store...")

        langchain_file = self.chunks_dir / "langchain_documents_comprehensive.json"
        if not langchain_file.exists():
            raise FileNotFoundError(f"Medical documents not found: {langchain_file}")

        with open(langchain_file, 'r', encoding='utf-8') as f:
            documents = json.load(f)

        logger.info(f"π Loaded {len(documents)} medical document chunks")
        return documents

    def create_vector_index(self, force_rebuild: bool = False) -> bool:
        """Load the persisted FAISS index if possible, otherwise build a new one.

        Args:
            force_rebuild: When True, ignore any index on disk and re-embed
                every chunk.

        Returns:
            True once an index has been loaded or built.

        Raises:
            FileNotFoundError: If a rebuild is needed and the chunk file is missing.
            ValueError: If no chunk passes the minimum-length filter.
        """
        if not force_rebuild and self.index_file.exists():
            # load_existing_index() reports failure by returning False rather
            # than raising, so its result must be checked; otherwise a corrupt
            # or incomplete store would never fall through to the rebuild below.
            try:
                if self.load_existing_index():
                    return True
            except Exception as e:
                logger.warning(f"Failed to load existing index: {e}")
            logger.info("Rebuilding index from scratch...")

        if self.embedding_model is None:
            self.initialize_embedding_model()

        documents = self.load_medical_documents()
        logger.info("Creating vector embeddings for all medical chunks...")

        # Keep texts and metadata as parallel lists so FAISS row i maps to both
        texts = []
        metadata = []
        for doc in documents:
            content = doc['page_content']
            # Skip very short chunks
            if len(content.strip()) < 50:
                continue
            texts.append(content)
            metadata.append(doc['metadata'])

        if not texts:
            raise ValueError(
                f"No usable document chunks (at least 50 characters) found in {self.chunks_dir}"
            )

        logger.info(f"Generating embeddings for {len(texts)} chunks...")
        start_time = time.time()
        embeddings = self.embedding_model.encode(
            texts,
            batch_size=32,
            show_progress_bar=True,
            convert_to_numpy=True
        )
        embed_time = time.time() - start_time
        logger.info(f"β‘ Embeddings generated in {embed_time:.2f} seconds")

        logger.info("Building FAISS index...")
        # faiss.normalize_L2 normalises in place and needs a C-contiguous
        # float32 array, so convert before normalising. With unit-length
        # vectors, IndexFlatIP's inner product equals cosine similarity.
        embeddings = np.ascontiguousarray(embeddings, dtype='float32')
        faiss.normalize_L2(embeddings)

        index = faiss.IndexFlatIP(self.embedding_dimension)
        index.add(embeddings)

        self.index = index
        self.documents = texts
        self.metadata = metadata

        self.save_index()

        logger.info(f"β Vector store created with {index.ntotal} embeddings")
        return True

    def load_existing_index(self) -> bool:
        """Load the persisted index, chunk texts, metadata and config from disk.

        ``self.index``, ``self.documents`` and ``self.metadata`` are replaced
        only after every file has been read and the row counts (and the
        embedding dimension) have been checked to agree.

        Returns:
            True on success, False on any failure (the error is logged).
        """
        logger.info("Loading existing vector store...")
        try:
            index = faiss.read_index(str(self.index_file))
            with open(self.documents_file, 'r', encoding='utf-8') as f:
                documents = json.load(f)
            with open(self.metadata_file, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
            with open(self.config_file, 'r', encoding='utf-8') as f:
                config = json.load(f)

            # search() maps FAISS row ids straight into documents/metadata,
            # so all three must have the same length.
            if not (index.ntotal == len(documents) == len(metadata)):
                raise ValueError(
                    f"inconsistent vector store: {index.ntotal} vectors, "
                    f"{len(documents)} documents, {len(metadata)} metadata entries"
                )

            self.embedding_model_name = config['embedding_model']
            self.embedding_dimension = config['embedding_dimension']

            # Query encoder; this also re-derives embedding_dimension from the model
            self.initialize_embedding_model()
            if index.d != self.embedding_dimension:
                raise ValueError(
                    f"index dimension {index.d} does not match embedding "
                    f"dimension {self.embedding_dimension}"
                )
        except Exception as e:
            logger.error(f"β Failed to load existing index: {e}")
            return False

        self.index = index
        self.documents = documents
        self.metadata = metadata

        logger.info(f"β Loaded existing vector store with {self.index.ntotal} embeddings")
        return True

    def save_index(self):
        """Write the FAISS index, texts, metadata and config to ``vector_store_dir``.

        Raises:
            Exception: Any write error is logged and re-raised.
        """
        logger.info("Saving vector store to disk...")
        try:
            faiss.write_index(self.index, str(self.index_file))

            with open(self.documents_file, 'w', encoding='utf-8') as f:
                json.dump(self.documents, f, ensure_ascii=False, indent=2)

            with open(self.metadata_file, 'w', encoding='utf-8') as f:
                json.dump(self.metadata, f, ensure_ascii=False, indent=2)

            config = {
                'embedding_model': self.embedding_model_name,
                'embedding_dimension': self.embedding_dimension,
                'total_chunks': len(self.documents),
                'created_at': time.strftime('%Y-%m-%d %H:%M:%S')
            }
            with open(self.config_file, 'w', encoding='utf-8') as f:
                json.dump(config, f, indent=2)

            logger.info(f"πΎ Vector store saved to {self.vector_store_dir}")
        except Exception as e:
            logger.error(f"β Failed to save vector store: {e}")
            raise

    def search(self,
               query: str,
               k: Optional[int] = None,
               filters: Optional[Dict[str, Any]] = None,
               min_score: Optional[float] = None) -> List["SearchResult"]:
        """Return up to ``k`` chunks most similar to ``query``.

        Args:
            query: Free-text query.
            k: Maximum number of results (defaults to ``self.default_k``).
            filters: Optional metadata filters (see ``_matches_filters``).
            min_score: Minimum cosine similarity (defaults to
                ``self.similarity_threshold``).

        Returns:
            SearchResult objects in the order FAISS returned them. Fewer than
            ``k`` may be returned when filters or the score threshold reject
            candidates, because only ``2 * k`` candidates are examined.

        Raises:
            ValueError: If no index has been created or loaded.
        """
        if self.index is None:
            raise ValueError("Vector store not initialized. Call create_vector_index() first.")

        if k is None:
            k = self.default_k
        if min_score is None:
            min_score = self.similarity_threshold

        # Over-fetch so filtering can still yield k hits, but never ask FAISS
        # for more rows than the index holds (or for a non-positive count).
        fetch_k = min(k * 2, self.index.ntotal)
        if fetch_k <= 0:
            return []

        query_embedding = np.ascontiguousarray(
            self.embedding_model.encode([query]), dtype='float32'
        )
        faiss.normalize_L2(query_embedding)

        scores, indices = self.index.search(query_embedding, fetch_k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing neighbours with id -1
            if idx == -1 or score < min_score:
                continue
            idx = int(idx)

            metadata = self.metadata[idx]
            if filters and not self._matches_filters(metadata, filters):
                continue

            results.append(SearchResult(
                content=self.documents[idx],
                score=float(score),
                metadata=metadata,
                chunk_index=idx,
                source_document=metadata.get('source', ''),
                chunk_type=metadata.get('chunk_type', 'text'),
                clinical_importance=metadata.get('clinical_importance', 0.5)
            ))

            if len(results) >= k:
                break

        return results

    def _matches_filters(self, metadata: Dict[str, Any], filters: Dict[str, Any]) -> bool:
        """Return True if ``metadata`` satisfies every entry in ``filters``.

        Filter values are interpreted as:
          * list or tuple -> the metadata value must be one of its items;
          * dict with optional ``'min'`` / ``'max'`` -> inclusive range bounds;
          * anything else -> exact equality.
        A key absent from ``metadata`` never matches, and a ``None`` value never
        satisfies a ``'min'`` or ``'max'`` bound.
        """
        for key, expected in filters.items():
            if key not in metadata:
                return False
            actual = metadata[key]

            if isinstance(expected, (list, tuple)):
                if actual not in expected:
                    return False
            elif isinstance(expected, dict):
                if 'min' in expected and (actual is None or actual < expected['min']):
                    return False
                if 'max' in expected and (actual is None or actual > expected['max']):
                    return False
            elif actual != expected:
                return False

        return True

    def search_by_medical_context(self,
                                  query: str,
                                  content_types: Optional[List[str]] = None,
                                  min_importance: float = 0.5,
                                  k: int = 5) -> List["SearchResult"]:
        """Search restricted by chunk type and a minimum clinical importance.

        Args:
            query: Free-text query.
            content_types: Allowed ``chunk_type`` values (no restriction if empty).
            min_importance: Minimum ``clinical_importance``; a positive value
                also excludes chunks whose metadata lacks that key.
            k: Maximum number of results.
        """
        filters = {}
        if content_types:
            filters['chunk_type'] = content_types
        if min_importance > 0:
            filters['clinical_importance'] = {'min': min_importance}

        return self.search(query, k=k, filters=filters)

    def get_statistics(self) -> Dict[str, Any]:
        """Summarise the loaded vector store.

        Returns:
            Vector count, embedding info, chunk-type counts, clinical-importance
            buckets (critical >= 0.9, high >= 0.7, medium >= 0.5, else low),
            the 10 most frequent sources, and the index file size in MB (0 if
            not saved). ``{'error': ...}`` if no index is loaded.
        """
        if self.index is None:
            return {'error': 'Vector store not initialized'}

        chunk_types = {}
        importance_distribution = {'low': 0, 'medium': 0, 'high': 0, 'critical': 0}
        sources = {}

        for meta in self.metadata:
            chunk_type = meta.get('chunk_type', 'unknown')
            chunk_types[chunk_type] = chunk_types.get(chunk_type, 0) + 1

            # Treat a missing or null importance as 0 (lowest bucket)
            importance = meta.get('clinical_importance') or 0
            if importance >= 0.9:
                importance_distribution['critical'] += 1
            elif importance >= 0.7:
                importance_distribution['high'] += 1
            elif importance >= 0.5:
                importance_distribution['medium'] += 1
            else:
                importance_distribution['low'] += 1

            source = meta.get('source', 'unknown')
            sources[source] = sources.get(source, 0) + 1

        # Rank by chunk count so this is really the top 10; sorted() is stable,
        # so equally frequent sources keep their first-seen order.
        top_sources = sorted(sources.items(), key=lambda item: item[1], reverse=True)[:10]

        return {
            'total_chunks': self.index.ntotal,
            'embedding_dimension': self.embedding_dimension,
            'embedding_model': self.embedding_model_name,
            'chunk_type_distribution': chunk_types,
            'clinical_importance_distribution': importance_distribution,
            'source_distribution': dict(top_sources),
            'vector_store_size_mb': self.index_file.stat().st_size / (1024*1024) if self.index_file.exists() else 0
        }
def main():
    """Build (or load) the maternal-health vector store, then smoke-test it.

    Runs a few representative clinical queries, logging the top 3 hits for
    each, and finishes by logging summary statistics. If index creation
    reports failure, an error is logged and nothing else runs.
    """
    logger.info("π Creating Maternal Health Vector Store...")

    store = MaternalHealthVectorStore()
    if not store.create_vector_index():
        logger.error("β Failed to create vector store")
        return

    # Exercise retrieval with representative clinical questions
    logger.info("\nπ Testing search functionality...")
    sample_queries = (
        "What is the recommended dosage of magnesium sulfate for preeclampsia?",
        "How to manage postpartum hemorrhage in emergency situations?",
        "Signs and symptoms of puerperal sepsis",
        "Normal fetal heart rate during labor",
    )
    for question in sample_queries:
        logger.info(f"\nπ Query: {question}")
        for rank, hit in enumerate(store.search(question, k=3), start=1):
            logger.info(f" {rank}. Score: {hit.score:.3f} | Type: {hit.chunk_type} | "
                        f"Importance: {hit.clinical_importance:.2f}")
            logger.info(f" Content: {hit.content[:100]}...")

    # Summarise what was built
    stats = store.get_statistics()
    logger.info("\nπ Vector Store Statistics:")
    logger.info(f" Total chunks: {stats['total_chunks']}")
    logger.info(f" Embedding dimension: {stats['embedding_dimension']}")
    tiers = stats['clinical_importance_distribution']
    logger.info(f" High importance chunks: {tiers['high'] + tiers['critical']}")
    logger.info(f" Vector store size: {stats['vector_store_size_mb']:.1f} MB")

    logger.info("\nβ Vector store creation and testing complete!")


if __name__ == "__main__":
    main()