Spaces:
Running
Running
File size: 6,578 Bytes
3ca1d38 696f787 3ca1d38 9659593 3ca1d38 9659593 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 9659593 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 9659593 3ca1d38 9659593 3ca1d38 9659593 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 9659593 3ca1d38 9659593 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 696f787 9659593 696f787 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 9659593 3ca1d38 9659593 3ca1d38 9659593 3ca1d38 696f787 3ca1d38 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 | """
MediGuard AI — FAISS Retriever
Local vector store retriever for development and HuggingFace Spaces.
Uses FAISS for fast similarity search on medical document embeddings.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
from src.services.retrieval.interface import BaseRetriever, RetrievalResult
logger = logging.getLogger(__name__)
# Guard import — faiss / langchain_community might not be installed in
# test or minimal environments. FAISS is set to None as a sentinel; the
# class constructors below check for it and raise ImportError with a
# clear message instead of failing at module import time.
try:
    from langchain_community.vectorstores import FAISS
except ImportError:
    FAISS = None  # type: ignore[assignment,misc]
class FAISSRetriever(BaseRetriever):
    """
    FAISS-based retriever for local development and HuggingFace deployment.

    Supports:
    - Semantic similarity search (default)
    - Maximal Marginal Relevance (MMR) for diversity
    - Score threshold filtering

    Does NOT support:
    - BM25 keyword search (vector-only)
    - Metadata filtering (FAISS limitation)
    """

    def __init__(
        self,
        vector_store: FAISS,
        *,
        search_type: str = "similarity",  # "similarity" or "mmr"
        score_threshold: float | None = None,
    ):
        """
        Initialize FAISS retriever.

        Args:
            vector_store: Loaded FAISS vector store instance
            search_type: "similarity" for cosine, "mmr" for diversity
            score_threshold: Minimum similarity (0-1) to include results;
                None disables filtering (0.0 is a valid, explicit threshold)

        Raises:
            ImportError: If langchain-community / FAISS is not installed
        """
        if FAISS is None:
            raise ImportError("langchain-community with FAISS is not installed")
        self._store = vector_store
        self._search_type = search_type
        self._score_threshold = score_threshold
        # Lazily populated by doc_count(); avoids touching the index until asked.
        self._doc_count_cache: int | None = None

    @classmethod
    def from_local(
        cls,
        vector_store_path: str,
        embedding_model,
        *,
        index_name: str = "medical_knowledge",
        **kwargs,
    ) -> FAISSRetriever:
        """
        Load FAISS retriever from a local directory.

        Args:
            vector_store_path: Directory containing .faiss and .pkl files
            embedding_model: Embedding model (must match creation model)
            index_name: Name of the index (default: medical_knowledge)
            **kwargs: Additional args passed to FAISSRetriever.__init__

        Returns:
            Initialized FAISSRetriever

        Raises:
            FileNotFoundError: If the index doesn't exist
            ImportError: If langchain-community / FAISS is not installed
        """
        if FAISS is None:
            raise ImportError("langchain-community with FAISS is not installed")
        store_path = Path(vector_store_path)
        index_path = store_path / f"{index_name}.faiss"
        if not index_path.exists():
            raise FileNotFoundError(f"FAISS index not found: {index_path}")
        logger.info("Loading FAISS index from %s", store_path)
        # SECURITY NOTE: allow_dangerous_deserialization=True uses pickle.
        # Only load from trusted, locally-built sources.
        store = FAISS.load_local(
            str(store_path),
            embedding_model,
            index_name=index_name,
            allow_dangerous_deserialization=True,
        )
        return cls(store, **kwargs)

    def retrieve(
        self,
        query: str,
        *,
        top_k: int = 5,
        filters: dict[str, Any] | None = None,
    ) -> list[RetrievalResult]:
        """
        Retrieve documents using FAISS similarity search.

        Args:
            query: Natural language query
            top_k: Maximum number of results
            filters: Ignored (FAISS doesn't support metadata filtering)

        Returns:
            List of RetrievalResult objects (empty list on any backend error)
        """
        if filters:
            logger.warning("FAISS does not support metadata filters; ignoring filters=%s", filters)
        try:
            if self._search_type == "mmr":
                # MMR provides diversity in results; over-fetch 2x so the
                # diversification step has candidates to choose from.
                docs_with_scores = self._store.max_marginal_relevance_search_with_score(
                    query, k=top_k, fetch_k=top_k * 2
                )
            else:
                # Standard similarity search
                docs_with_scores = self._store.similarity_search_with_score(query, k=top_k)
            results = []
            for doc, score in docs_with_scores:
                # FAISS returns L2 distance (lower = better), convert to similarity.
                # Assumes normalized embeddings where L2 distance is in [0, 2]
                # — TODO confirm the embedding model normalizes vectors.
                # Similarity = 1 - (distance / 2), clamped to [0, 1]
                similarity = max(0.0, min(1.0, 1 - score / 2))
                # Apply score threshold. BUG FIX: compare against None explicitly —
                # the old truthiness check silently disabled a threshold of 0.0.
                if self._score_threshold is not None and similarity < self._score_threshold:
                    continue
                results.append(
                    RetrievalResult(
                        # NOTE: hash() of the content is only a fallback id and is
                        # not stable across interpreter runs (str hash randomization).
                        doc_id=str(doc.metadata.get("chunk_id", hash(doc.page_content))),
                        content=doc.page_content,
                        score=similarity,
                        metadata=doc.metadata,
                    )
                )
            logger.debug("FAISS retrieved %d results for query: %s...", len(results), query[:50])
            return results
        except Exception:
            # Best-effort contract: callers get an empty result set rather than
            # an exception. logger.exception preserves the traceback for debugging.
            logger.exception("FAISS retrieval failed")
            return []

    def health(self) -> bool:
        """Check if FAISS store is loaded."""
        return self._store is not None

    def doc_count(self) -> int:
        """Return number of indexed chunks (cached after first lookup)."""
        if self._doc_count_cache is None:
            try:
                self._doc_count_cache = self._store.index.ntotal
            except Exception:
                # Index not available / malformed — report zero instead of raising.
                self._doc_count_cache = 0
        return self._doc_count_cache

    @property
    def backend_name(self) -> str:
        """Human-readable backend identifier."""
        return "FAISS (local)"
# Convenience factory with project defaults.
def make_faiss_retriever(
    vector_store_path: str = "data/vector_stores",
    embedding_model=None,
    index_name: str = "medical_knowledge",
) -> FAISSRetriever:
    """
    Build a :class:`FAISSRetriever` with sensible defaults.

    Args:
        vector_store_path: Directory holding the persisted vector store
        embedding_model: Embedding model; when None, the project default
            is loaded via ``src.llm_config.get_embedding_model``
        index_name: Name of the index files inside the store directory

    Returns:
        Configured FAISSRetriever
    """
    model = embedding_model
    if model is None:
        # Imported lazily so importing this module never pulls in the
        # embedding stack unless the default model is actually needed.
        from src.llm_config import get_embedding_model

        model = get_embedding_model()
    return FAISSRetriever.from_local(vector_store_path, model, index_name=index_name)
|