Agentic-RagBot / src /services /retrieval /faiss_retriever.py
T0X1N's picture
chore: codebase audit and fixes (ruff, mypy, pytest)
9659593
"""
MediGuard AI — FAISS Retriever
Local vector store retriever for development and HuggingFace Spaces.
Uses FAISS for fast similarity search on medical document embeddings.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
from src.services.retrieval.interface import BaseRetriever, RetrievalResult
logger = logging.getLogger(__name__)
# Guard import — faiss might not be installed in test environments
try:
from langchain_community.vectorstores import FAISS
except ImportError:
FAISS = None # type: ignore[assignment,misc]
class FAISSRetriever(BaseRetriever):
"""
FAISS-based retriever for local development and HuggingFace deployment.
Supports:
- Semantic similarity search (default)
- Maximal Marginal Relevance (MMR) for diversity
- Score threshold filtering
Does NOT support:
- BM25 keyword search (vector-only)
- Metadata filtering (FAISS limitation)
"""
def __init__(
self,
vector_store: FAISS,
*,
search_type: str = "similarity", # "similarity" or "mmr"
score_threshold: float | None = None,
):
"""
Initialize FAISS retriever.
Args:
vector_store: Loaded FAISS vector store instance
search_type: "similarity" for cosine, "mmr" for diversity
score_threshold: Minimum score (0-1) to include results
"""
if FAISS is None:
raise ImportError("langchain-community with FAISS is not installed")
self._store = vector_store
self._search_type = search_type
self._score_threshold = score_threshold
self._doc_count_cache: int | None = None
@classmethod
def from_local(
cls,
vector_store_path: str,
embedding_model,
*,
index_name: str = "medical_knowledge",
**kwargs,
) -> FAISSRetriever:
"""
Load FAISS retriever from a local directory.
Args:
vector_store_path: Directory containing .faiss and .pkl files
embedding_model: Embedding model (must match creation model)
index_name: Name of the index (default: medical_knowledge)
**kwargs: Additional args passed to FAISSRetriever.__init__
Returns:
Initialized FAISSRetriever
Raises:
FileNotFoundError: If the index doesn't exist
"""
if FAISS is None:
raise ImportError("langchain-community with FAISS is not installed")
store_path = Path(vector_store_path)
index_path = store_path / f"{index_name}.faiss"
if not index_path.exists():
raise FileNotFoundError(f"FAISS index not found: {index_path}")
logger.info("Loading FAISS index from %s", store_path)
# SECURITY NOTE: allow_dangerous_deserialization=True uses pickle.
# Only load from trusted, locally-built sources.
store = FAISS.load_local(
str(store_path),
embedding_model,
index_name=index_name,
allow_dangerous_deserialization=True,
)
return cls(store, **kwargs)
def retrieve(
self,
query: str,
*,
top_k: int = 5,
filters: dict[str, Any] | None = None,
) -> list[RetrievalResult]:
"""
Retrieve documents using FAISS similarity search.
Args:
query: Natural language query
top_k: Maximum number of results
filters: Ignored (FAISS doesn't support metadata filtering)
Returns:
List of RetrievalResult objects
"""
if filters:
logger.warning("FAISS does not support metadata filters; ignoring filters=%s", filters)
try:
if self._search_type == "mmr":
# MMR provides diversity in results
docs_with_scores = self._store.max_marginal_relevance_search_with_score(
query, k=top_k, fetch_k=top_k * 2
)
else:
# Standard similarity search
docs_with_scores = self._store.similarity_search_with_score(query, k=top_k)
results = []
for doc, score in docs_with_scores:
# FAISS returns L2 distance (lower = better), convert to similarity
# Assumes normalized embeddings where L2 distance is in [0, 2]
# Similarity = 1 - (distance / 2), clamped to [0, 1]
similarity = max(0.0, min(1.0, 1 - score / 2))
# Apply score threshold
if self._score_threshold and similarity < self._score_threshold:
continue
results.append(
RetrievalResult(
doc_id=str(doc.metadata.get("chunk_id", hash(doc.page_content))),
content=doc.page_content,
score=similarity,
metadata=doc.metadata,
)
)
logger.debug("FAISS retrieved %d results for query: %s...", len(results), query[:50])
return results
except Exception as exc:
logger.error("FAISS retrieval failed: %s", exc)
return []
def health(self) -> bool:
"""Check if FAISS store is loaded."""
return self._store is not None
def doc_count(self) -> int:
"""Return number of indexed chunks."""
if self._doc_count_cache is None:
try:
self._doc_count_cache = self._store.index.ntotal
except Exception:
self._doc_count_cache = 0
return self._doc_count_cache
@property
def backend_name(self) -> str:
return "FAISS (local)"
# Factory function for quick setup
def make_faiss_retriever(
vector_store_path: str = "data/vector_stores",
embedding_model=None,
index_name: str = "medical_knowledge",
) -> FAISSRetriever:
"""
Create a FAISS retriever with sensible defaults.
Args:
vector_store_path: Path to vector store directory
embedding_model: Embedding model (auto-loaded if None)
index_name: Index name
Returns:
Configured FAISSRetriever
"""
if embedding_model is None:
from src.llm_config import get_embedding_model
embedding_model = get_embedding_model()
return FAISSRetriever.from_local(
vector_store_path,
embedding_model,
index_name=index_name,
)