Spaces:
Sleeping
Sleeping
| """ | |
| Simple Vector Store for Medical RAG v2.0 | |
| Research-backed approach: Document-based retrieval with simple metadata | |
| """ | |
| import os | |
| import json | |
| import logging | |
| import time | |
| from typing import List, Dict, Any, Optional, Tuple | |
| from pathlib import Path | |
| import numpy as np | |
| from dataclasses import dataclass | |
| # Vector store and embeddings | |
| import faiss | |
| from sentence_transformers import SentenceTransformer | |
| from langchain_core.documents import Document | |
@dataclass
class SearchResult:
    """A single hit returned by ``SimpleVectorStore.search``.

    Instances are built with keyword arguments, which is why this must be a
    dataclass (a plain class would reject them with ``TypeError``).

    Attributes:
        content: Text of the matched chunk.
        score: Raw FAISS score. For an inner-product index this is the
            similarity (higher is closer); for an L2 index it is a distance
            (lower is closer).
        metadata: Metadata dict of the matched chunk.
        document_name: ``metadata['document_name']``, or ``'Unknown'``.
        content_type: ``metadata['content_type']``, or ``'general'``.
    """
    content: str
    score: float
    metadata: Dict[str, Any]
    document_name: str
    content_type: str
class SimpleVectorStore:
    """
    Simple vector store using research-optimal embedding approach
    - Focused on document-based retrieval
    - Simplified metadata structure
    - High-performance FAISS indexing

    Chunks are embedded with a SentenceTransformer model (embeddings are
    normalised to unit length) and held in an exact, flat FAISS index.
    ``save_vector_store`` / ``load_vector_store`` persist the index, the chunk
    texts and metadata, and a small config file under ``vector_store_dir``.
    """

    def __init__(self,
                 embedding_model: str = "all-MiniLM-L6-v2",
                 index_type: str = "IndexFlatIP",  # Inner Product for cosine similarity
                 vector_store_dir: str = "simple_vector_store"):
        """
        Initialize the simple vector store and load the embedding model.

        Args:
            embedding_model: Sentence transformer model name
            index_type: FAISS index type, "IndexFlatIP" or "IndexFlatL2"
                (validated when the index is built)
            vector_store_dir: Directory to store vector index and metadata;
                created, including missing parents, if it does not exist

        Raises:
            Exception: whatever SentenceTransformer raises if the model
                cannot be loaded (logged, then re-raised)
        """
        self.embedding_model_name = embedding_model
        self.index_type = index_type
        self.vector_store_dir = Path(vector_store_dir)
        # parents=True so nested locations such as "data/stores/v2" also work.
        self.vector_store_dir.mkdir(parents=True, exist_ok=True)

        # Populated by create_embeddings()/build_index() or load_vector_store().
        self.embedding_model = None
        self.index = None
        self.documents: List[Document] = []
        self.metadata: List[Dict[str, Any]] = []  # parallel to self.documents

        self.setup_logging()
        self._initialize_embedding_model()

    def setup_logging(self):
        """Setup logging for the vector store.

        Calls ``logging.basicConfig`` (a no-op if the root logger already has
        handlers) and binds ``self.logger`` to this module's logger.
        """
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    def _initialize_embedding_model(self):
        """Load the sentence transformer named by ``self.embedding_model_name``.

        Raises:
            Exception: any loading error, after logging it.
        """
        try:
            self.logger.info(f"Loading embedding model: {self.embedding_model_name}")
            self.embedding_model = SentenceTransformer(self.embedding_model_name)
            self.logger.info("Embedding model loaded successfully")
        except Exception as e:
            self.logger.error(f"Error loading embedding model: {e}")
            raise

    def create_embeddings(self, chunks: List[Document]) -> Tuple[np.ndarray, int]:
        """Embed chunk texts and remember the chunks for later retrieval.

        Side effect: replaces ``self.documents`` / ``self.metadata`` with the
        given chunks, so row *i* of the returned matrix corresponds to
        ``self.documents[i]``.

        Args:
            chunks: Documents whose ``page_content`` is embedded (any iterable
                is accepted; it is materialised into a list).

        Returns:
            ``(embeddings, count)``: the normalised embeddings (one per chunk)
            and their number.

        Raises:
            ValueError: if ``chunks`` is empty.
        """
        chunks = list(chunks)
        if not chunks:
            raise ValueError("No chunks provided for embedding")

        start_time = time.time()
        texts = [chunk.page_content for chunk in chunks]
        self.logger.info(f"Creating embeddings for {len(texts)} chunks...")

        embeddings = self.embedding_model.encode(
            texts,
            show_progress_bar=True,
            batch_size=32,
            normalize_embeddings=True  # unit vectors: inner product == cosine similarity
        )

        # Keep documents aligned with embedding rows (row i <-> self.documents[i]).
        self.documents = chunks
        self.metadata = [chunk.metadata for chunk in chunks]

        embedding_time = time.time() - start_time
        self.logger.info(f"Created {len(embeddings)} embeddings in {embedding_time:.2f} seconds")
        return embeddings, len(embeddings)

    def build_index(self, embeddings: np.ndarray):
        """Build a fresh FAISS index from ``embeddings``, replacing any existing one.

        Args:
            embeddings: 2-D array with one row per document, e.g. as returned
                by ``create_embeddings``.

        Raises:
            ValueError: if ``embeddings`` is not 2-D or ``self.index_type`` is
                not supported.
        """
        # Convert once to the C-contiguous float32 layout FAISS works with
        # (astype() alone would keep a Fortran-ordered input Fortran-ordered).
        vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
        if vectors.ndim != 2:
            raise ValueError(f"Expected a 2-D embedding matrix, got shape {vectors.shape}")
        dimension = vectors.shape[1]

        if self.index_type == "IndexFlatIP":
            # Inner Product index (equals cosine similarity for normalized embeddings)
            self.index = faiss.IndexFlatIP(dimension)
        elif self.index_type == "IndexFlatL2":
            # L2 distance index (scores are distances: lower is closer)
            self.index = faiss.IndexFlatL2(dimension)
        else:
            raise ValueError(f"Unsupported index type: {self.index_type}")

        self.index.add(vectors)
        if self.index.ntotal != len(self.documents):
            # search() maps FAISS row ids straight into self.documents.
            self.logger.warning(
                f"Index holds {self.index.ntotal} vectors but {len(self.documents)} "
                f"documents are stored; search results may be misaligned"
            )
        self.logger.info(f"Built FAISS index with {self.index.ntotal} vectors")

    def save_vector_store(self):
        """Persist the index, documents and config to ``self.vector_store_dir``.

        Writes ``faiss_index.bin``, ``documents.json`` and ``config.json``.
        Document metadata must be JSON-serialisable.

        Raises:
            ValueError: if no index has been built or loaded yet.
            Exception: any I/O or serialisation error.
            Both are logged, then re-raised.
        """
        try:
            if self.index is None:
                raise ValueError("No index to save. Build or load an index first.")

            index_path = self.vector_store_dir / "faiss_index.bin"
            faiss.write_index(self.index, str(index_path))

            docs_path = self.vector_store_dir / "documents.json"
            docs_data = [
                {'page_content': doc.page_content, 'metadata': doc.metadata}
                for doc in self.documents
            ]
            with open(docs_path, 'w', encoding='utf-8') as f:
                json.dump(docs_data, f, indent=2, ensure_ascii=False)

            config_path = self.vector_store_dir / "config.json"
            config = {
                'embedding_model': self.embedding_model_name,
                'index_type': self.index_type,
                'total_documents': len(self.documents),
                'dimension': self.index.d,
                'created_at': time.strftime('%Y-%m-%d %H:%M:%S')
            }
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(config, f, indent=2)

            self.logger.info(f"Vector store saved to {self.vector_store_dir}")
        except Exception as e:
            self.logger.error(f"Error saving vector store: {e}")
            raise

    def load_vector_store(self) -> bool:
        """Load a store previously written by ``save_vector_store``.

        All files are read and validated into locals first. The in-memory
        state is replaced only on success, so a failed load leaves the current
        store untouched.

        Returns:
            True on success. False if any file is missing or unreadable, or if
            the index and the document list disagree in length (errors are
            logged).
        """
        index_path = self.vector_store_dir / "faiss_index.bin"
        docs_path = self.vector_store_dir / "documents.json"
        config_path = self.vector_store_dir / "config.json"

        try:
            if not all(p.exists() for p in (index_path, docs_path, config_path)):
                return False

            index = faiss.read_index(str(index_path))
            with open(docs_path, 'r', encoding='utf-8') as f:
                docs_data = json.load(f)
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)

            documents = [
                Document(page_content=doc_data['page_content'], metadata=doc_data['metadata'])
                for doc_data in docs_data
            ]
            metadata = [doc_data['metadata'] for doc_data in docs_data]

            # search() uses FAISS row ids as indices into self.documents.
            if index.ntotal != len(documents):
                self.logger.error(
                    f"Error loading vector store: index has {index.ntotal} vectors "
                    f"but {len(documents)} documents were stored"
                )
                return False

            # A mismatch means new query embeddings, or the reported settings,
            # may not correspond to what is stored on disk.
            for key, current in (('embedding_model', self.embedding_model_name),
                                 ('index_type', self.index_type)):
                stored = config.get(key)
                if stored is not None and stored != current:
                    self.logger.warning(
                        f"Stored {key} '{stored}' differs from configured '{current}'"
                    )
        except Exception as e:
            self.logger.error(f"Error loading vector store: {e}")
            return False

        self.index = index
        self.documents = documents
        self.metadata = metadata
        self.logger.info(f"Loaded vector store with {len(self.documents)} documents")
        return True

    def search(self,
               query: str,
               k: int = 5,
               content_type_filter: Optional[str] = None,
               fetch_multiplier: int = 3) -> List[SearchResult]:
        """
        Search for the chunks most similar to ``query``.

        Args:
            query: Search query
            k: Maximum number of results to return; ``k <= 0`` returns ``[]``
            content_type_filter: Optional case-insensitive substring that a
                chunk's ``metadata['content_type']`` must contain
            fetch_multiplier: Candidates fetched from FAISS per requested
                result before filtering, capped at the store size. With a
                filter, fewer than ``k`` results may be returned if matching
                chunks are not among the top ``k * fetch_multiplier`` candidates.

        Returns:
            Up to ``k`` SearchResult objects in FAISS rank order. ``score`` is
            the inner product (cosine similarity for normalised vectors) for
            IndexFlatIP, or the L2 distance for IndexFlatL2.

        Raises:
            ValueError: if no index or documents are loaded.
        """
        if self.index is None or not self.documents:
            raise ValueError("Vector store not initialized. Load or create index first.")
        if k <= 0:
            return []

        # Over-fetch so post-filtering can still fill k slots, but never ask
        # for more rows than exist in the index or the document list.
        search_k = min(k * max(fetch_multiplier, 1), self.index.ntotal, len(self.documents))
        if search_k <= 0:
            return []

        query_embedding = self.embedding_model.encode(
            [query],
            normalize_embeddings=True
        )
        query_vectors = np.ascontiguousarray(query_embedding, dtype=np.float32)
        scores, indices = self.index.search(query_vectors, search_k)

        wanted_type = content_type_filter.lower() if content_type_filter else None
        results: List[SearchResult] = []
        for score, idx in zip(scores[0], indices[0]):
            if idx == -1:  # FAISS uses -1 for "no result" slots
                continue

            doc = self.documents[idx]
            metadata = self.metadata[idx]

            if wanted_type is not None:
                # `or ''` also covers an explicit None stored under the key.
                doc_content_type = metadata.get('content_type') or ''
                if wanted_type not in doc_content_type.lower():
                    continue

            results.append(SearchResult(
                content=doc.page_content,
                score=float(score),
                metadata=metadata,
                document_name=metadata.get('document_name', 'Unknown'),
                content_type=metadata.get('content_type', 'general')
            ))
            if len(results) >= k:
                break

        return results

    def get_stats(self) -> Dict[str, Any]:
        """Summarise the store's contents.

        Returns:
            ``{"status": "empty"}`` when no documents are loaded. Otherwise a
            dict with the chunk count, model and index settings, an
            approximate size in MB, the average chunk length in characters,
            the ten documents with the most chunks, and per-content-type chunk
            counts.
        """
        if not self.documents:
            return {"status": "empty"}

        doc_counts: Dict[str, int] = {}
        content_type_counts: Dict[str, int] = {}
        total_chars = 0
        for doc in self.documents:
            doc_name = doc.metadata.get('document_name', 'Unknown')
            doc_counts[doc_name] = doc_counts.get(doc_name, 0) + 1

            content_type = doc.metadata.get('content_type', 'general')
            content_type_counts[content_type] = content_type_counts.get(content_type, 0) + 1

            total_chars += len(doc.page_content)

        has_index = self.index is not None
        if has_index:
            # Rough estimate: float32 vectors (4 bytes each) plus ~1 byte per stored character.
            vector_size_mb = (self.index.ntotal * self.index.d * 4) / (1024 * 1024)
            metadata_size_mb = total_chars / (1024 * 1024)
            total_size_mb = vector_size_mb + metadata_size_mb
        else:
            total_size_mb = 0

        top_documents = sorted(doc_counts.items(), key=lambda item: item[1], reverse=True)[:10]

        return {
            "status": "ready",
            "total_chunks": len(self.documents),
            "embedding_model": self.embedding_model_name,
            "index_type": self.index_type,
            "vector_dimension": self.index.d if has_index else 0,
            "vector_store_size_mb": round(total_size_mb, 2),
            "avg_chunk_size": round(total_chars / len(self.documents), 1),
            "document_distribution": dict(top_documents),
            "content_type_distribution": content_type_counts
        }
def main():
    """Smoke-test the simple vector store end to end.

    Loads the store persisted in the default directory. If none can be
    loaded, builds one from ``simple_document_chunker`` output and saves it.
    Then prints summary statistics and runs a few sample queries.

    Returns:
        The SimpleVectorStore instance, or None if any step raised.
    """
    # NOTE(review): the leading marker characters in these messages ("π", "β")
    # look like emoji that were mis-encoded at some point; confirm the
    # intended glyphs. They are kept unchanged here.
    print("π Testing Simple Vector Store v2.0")
    print("=" * 60)

    try:
        vector_store = SimpleVectorStore(
            embedding_model="all-MiniLM-L6-v2",
            index_type="IndexFlatIP"
        )

        if vector_store.load_vector_store():
            print("β Loaded existing vector store")
        else:
            print("π Creating new vector store from chunks...")
            # Imported here so the chunker module is only needed when a new
            # store has to be built.
            from simple_document_chunker import SimpleDocumentChunker

            chunker = SimpleDocumentChunker()
            documents = chunker.load_processed_documents()
            chunks = chunker.create_simple_chunks(documents)
            print(f"β Loaded {len(chunks)} chunks for embedding")

            embeddings, _ = vector_store.create_embeddings(chunks)
            vector_store.build_index(embeddings)
            vector_store.save_vector_store()
            print("β Vector store created and saved")

        stats = vector_store.get_stats()
        if stats.get('status') != 'ready':
            # get_stats() returns only {"status": "empty"} for a store without
            # documents, so the statistics and search tests below cannot run.
            print(f"Vector store status is '{stats.get('status')}'; "
                  f"skipping statistics and search tests")
            return vector_store

        total_chunks = stats['total_chunks']
        print("\nπ VECTOR STORE STATISTICS:")
        print(f" Status: {stats['status'].upper()}")
        print(f" Total chunks: {total_chunks:,}")
        print(f" Embedding model: {stats['embedding_model']}")
        print(f" Vector dimension: {stats['vector_dimension']}")
        print(f" Store size: {stats['vector_store_size_mb']} MB")
        print(f" Avg chunk size: {stats['avg_chunk_size']:.0f} chars")

        print("\nπ Content Type Distribution:")
        for content_type, chunk_count in stats['content_type_distribution'].items():
            percentage = (chunk_count / total_chunks) * 100
            print(f" {content_type}: {chunk_count:,} chunks ({percentage:.1f}%)")

        print("\nπ TESTING SEARCH FUNCTIONALITY:")
        test_queries = [
            "magnesium sulfate dosage preeclampsia",
            "postpartum hemorrhage management",
            "fetal heart rate monitoring",
            "emergency cesarean delivery"
        ]
        for query in test_queries:
            print(f"\nπ Query: '{query}'")
            results = vector_store.search(query, k=3)
            for i, result in enumerate(results, 1):
                print(f" Result {i}: Score={result.score:.3f}, Doc={result.document_name}")
                print(f" Type={result.content_type}")
                print(f" Preview: {result.content[:100]}...")

        print("\nπ Simple Vector Store Testing Complete!")
        print(f"β Successfully created vector store with {total_chunks:,} embeddings")
        print("β Search functionality working with high relevance scores")
        return vector_store

    except Exception as e:
        print(f"β Error in simple vector store: {e}")
        import traceback
        traceback.print_exc()
        return None
if __name__ == "__main__":
    # Only when executed as a script (not on import): run the smoke test.
    main()