Spaces:
Sleeping
Sleeping
| """ | |
| MediGuard AI — FAISS Retriever | |
| Local vector store retriever for development and HuggingFace Spaces. | |
| Uses FAISS for fast similarity search on medical document embeddings. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from pathlib import Path | |
| from typing import Any | |
| from src.services.retrieval.interface import BaseRetriever, RetrievalResult | |
| logger = logging.getLogger(__name__) | |
| # Guard import — faiss might not be installed in test environments | |
| try: | |
| from langchain_community.vectorstores import FAISS | |
| except ImportError: | |
| FAISS = None # type: ignore[assignment,misc] | |
class FAISSRetriever(BaseRetriever):
    """
    FAISS-based retriever for local development and HuggingFace deployment.

    Supports:
        - Semantic similarity search (default)
        - Maximal Marginal Relevance (MMR) for diversity
        - Score threshold filtering

    Does NOT support:
        - BM25 keyword search (vector-only)
        - Metadata filtering (FAISS limitation)
    """

    def __init__(
        self,
        vector_store: FAISS,
        *,
        search_type: str = "similarity",  # "similarity" or "mmr"
        score_threshold: float | None = None,
    ):
        """
        Initialize FAISS retriever.

        Args:
            vector_store: Loaded FAISS vector store instance
            search_type: "similarity" for cosine, "mmr" for diversity
            score_threshold: Minimum similarity (0-1) to include results;
                None disables filtering.
        """
        if FAISS is None:
            raise ImportError("langchain-community with FAISS is not installed")
        self._store = vector_store
        self._search_type = search_type
        self._score_threshold = score_threshold
        # Lazily populated by doc_count(); avoids touching the index until asked.
        self._doc_count_cache: int | None = None

    # BUG FIX: decorator was missing — without it, calling
    # FAISSRetriever.from_local(path, model) binds `path` to `cls` and fails.
    @classmethod
    def from_local(
        cls,
        vector_store_path: str,
        embedding_model,
        *,
        index_name: str = "medical_knowledge",
        **kwargs,
    ) -> FAISSRetriever:
        """
        Load FAISS retriever from a local directory.

        Args:
            vector_store_path: Directory containing .faiss and .pkl files
            embedding_model: Embedding model (must match creation model)
            index_name: Name of the index (default: medical_knowledge)
            **kwargs: Additional args passed to FAISSRetriever.__init__

        Returns:
            Initialized FAISSRetriever

        Raises:
            FileNotFoundError: If the index doesn't exist
            ImportError: If FAISS support is not installed
        """
        if FAISS is None:
            raise ImportError("langchain-community with FAISS is not installed")
        store_path = Path(vector_store_path)
        index_path = store_path / f"{index_name}.faiss"
        if not index_path.exists():
            raise FileNotFoundError(f"FAISS index not found: {index_path}")
        logger.info("Loading FAISS index from %s", store_path)
        # SECURITY NOTE: allow_dangerous_deserialization=True uses pickle.
        # Only load from trusted, locally-built sources.
        store = FAISS.load_local(
            str(store_path),
            embedding_model,
            index_name=index_name,
            allow_dangerous_deserialization=True,
        )
        return cls(store, **kwargs)

    def retrieve(
        self,
        query: str,
        *,
        top_k: int = 5,
        filters: dict[str, Any] | None = None,
    ) -> list[RetrievalResult]:
        """
        Retrieve documents using FAISS similarity search.

        Args:
            query: Natural language query
            top_k: Maximum number of results
            filters: Ignored (FAISS doesn't support metadata filtering)

        Returns:
            List of RetrievalResult objects; empty list on backend failure.
        """
        if filters:
            logger.warning("FAISS does not support metadata filters; ignoring filters=%s", filters)
        try:
            if self._search_type == "mmr":
                # MMR provides diversity in results; fetch_k over-fetches so
                # there are candidates to diversify among.
                docs_with_scores = self._store.max_marginal_relevance_search_with_score(
                    query, k=top_k, fetch_k=top_k * 2
                )
            else:
                # Standard similarity search
                docs_with_scores = self._store.similarity_search_with_score(query, k=top_k)
            results = []
            for doc, score in docs_with_scores:
                # FAISS returns L2 distance (lower = better), convert to similarity
                # Assumes normalized embeddings where L2 distance is in [0, 2]
                # Similarity = 1 - (distance / 2), clamped to [0, 1]
                similarity = max(0.0, min(1.0, 1 - score / 2))
                # BUG FIX: compare against None explicitly — the old truthiness
                # check silently disabled a threshold of 0.0.
                if self._score_threshold is not None and similarity < self._score_threshold:
                    continue
                results.append(
                    RetrievalResult(
                        doc_id=str(doc.metadata.get("chunk_id", hash(doc.page_content))),
                        content=doc.page_content,
                        score=similarity,
                        metadata=doc.metadata,
                    )
                )
            logger.debug("FAISS retrieved %d results for query: %s...", len(results), query[:50])
            return results
        except Exception:
            # Deliberate best-effort: retrieval failure degrades to no results.
            # logger.exception preserves the traceback (logger.error dropped it).
            logger.exception("FAISS retrieval failed")
            return []

    def health(self) -> bool:
        """Check if FAISS store is loaded."""
        return self._store is not None

    def doc_count(self) -> int:
        """Return number of indexed chunks (cached after first lookup)."""
        if self._doc_count_cache is None:
            try:
                self._doc_count_cache = self._store.index.ntotal
            except Exception:
                # Missing/partial index: report zero rather than crash a status probe.
                self._doc_count_cache = 0
        return self._doc_count_cache

    def backend_name(self) -> str:
        """Human-readable backend identifier for diagnostics."""
        return "FAISS (local)"
# Convenience factory with project defaults
def make_faiss_retriever(
    vector_store_path: str = "data/vector_stores",
    embedding_model=None,
    index_name: str = "medical_knowledge",
) -> FAISSRetriever:
    """
    Create a FAISS retriever with sensible defaults.

    Args:
        vector_store_path: Path to vector store directory
        embedding_model: Embedding model (auto-loaded if None)
        index_name: Index name

    Returns:
        Configured FAISSRetriever
    """
    model = embedding_model
    if model is None:
        # Import here so the default model is only loaded on demand.
        from src.llm_config import get_embedding_model

        model = get_embedding_model()
    return FAISSRetriever.from_local(vector_store_path, model, index_name=index_name)