Spaces:
Running
Running
| """Dense retrieval using FAISS vector similarity search. | |
| This module provides the DenseRetriever class for semantic retrieval | |
| using dense embeddings. The retriever combines: | |
| - FAISS index for efficient nearest neighbor search | |
| - BGE encoder for query embedding | |
| - ChunkStore for metadata lookup | |
| The dense retriever is the primary retrieval component, providing | |
| semantic understanding of queries through embedding similarity. | |
| Design Decisions: | |
| - The encoder is lazy loaded to avoid heavy dependencies (torch, | |
| sentence-transformers) when using prebuilt indexes | |
| - Score normalization maps FAISS inner product scores to [0, 1] range | |
| - Missing chunks are handled gracefully with warnings (index/store mismatch) | |
| Lazy Loading: | |
| The BGEEncoder is loaded on first retrieve() call if not provided | |
| in the constructor. This follows the project convention of lazy-loading | |
| heavy dependencies. | |
| Example: | |
| ------- | |
| >>> from pathlib import Path | |
| >>> from rag_chatbot.retrieval import DenseRetriever, FAISSIndex, ChunkStore | |
| >>> # Load index and chunks | |
| >>> index = FAISSIndex.load(Path("data/index.faiss")) | |
| >>> chunks = ChunkStore(Path("data/chunks/chunks.jsonl")) | |
| >>> # Create retriever | |
| >>> retriever = DenseRetriever(faiss_index=index, chunk_store=chunks) | |
| >>> # Retrieve | |
| >>> results = retriever.retrieve("What is PMV?", top_k=5) | |
| >>> for result in results: | |
| ... print(f"{result.chunk_id}: {result.score:.3f}") | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from typing import TYPE_CHECKING | |
| if TYPE_CHECKING: | |
| from rag_chatbot.embeddings.encoder import BGEEncoder | |
| from .chunk_store import ChunkStore | |
| from .faiss_index import FAISSIndex | |
| # Import the models directly (lightweight Pydantic models) | |
| from .models import RetrievalResult, normalize_text | |
| # ============================================================================= | |
| # Module Exports | |
| # ============================================================================= | |
| __all__: list[str] = ["DenseRetriever"] | |
| # ============================================================================= | |
| # Logger | |
| # ============================================================================= | |
| logger = logging.getLogger(__name__) | |
class DenseRetriever:
    """Dense retriever using FAISS for semantic similarity search.

    Implements the Retriever protocol using dense vector embeddings.
    It combines:

    - FAISSIndex for efficient nearest neighbor search
    - ChunkStore for joining search results with chunk metadata
    - BGEEncoder for encoding queries into embedding vectors

    The retriever uses inner product similarity (from FAISS IndexFlatIP)
    which works well with normalized embeddings like those produced by
    BGE models. Scores are normalized to [0, 1] range for consistency
    with other retrievers.

    Score Normalization:
        FAISS inner product scores for normalized vectors are in [-1, 1].
        We normalize to [0, 1] using: normalized = (score + 1) / 2
        This maps:
        - Perfect similarity (1.0) -> 1.0
        - Orthogonal (0.0) -> 0.5
        - Opposite (-1.0) -> 0.0

    Lazy Loading Pattern:
        The BGEEncoder is loaded lazily on first retrieve() call if not
        provided in the constructor. This avoids loading torch and
        sentence-transformers until actually needed.

    Attributes:
    ----------
    faiss_index : FAISSIndex
        The FAISS index used for similarity search (read-only property).
    chunk_store : ChunkStore
        Store for looking up chunk metadata by ID (read-only property).

    Example:
    -------
    >>> # With lazy-loaded encoder
    >>> retriever = DenseRetriever(faiss_index=index, chunk_store=chunks)
    >>> results = retriever.retrieve("thermal comfort calculation")
    >>> # With pre-loaded encoder
    >>> encoder = BGEEncoder()
    >>> retriever = DenseRetriever(
    ...     faiss_index=index,
    ...     chunk_store=chunks,
    ...     encoder=encoder,
    ... )

    Note:
    ----
    The retriever implements the Retriever protocol defined in
    rag_chatbot.retrieval.models, enabling use with HybridRetriever
    and dependency injection in tests.
    """

    def __init__(
        self,
        faiss_index: FAISSIndex,
        chunk_store: ChunkStore,
        encoder: BGEEncoder | None = None,
    ) -> None:
        """Initialize the dense retriever.

        Args:
        ----
        faiss_index: FAISS index for vector similarity search. Must be
            a trained/loaded FAISSIndex instance built with embeddings
            from the same model as the encoder (default:
            bge-small-en-v1.5).
        chunk_store: Store for chunk metadata lookup. Must contain
            chunks matching the indexed chunk_ids.
        encoder: Optional BGE encoder for query embedding. If None, a
            BGEEncoder is created lazily on the first retrieve() call.
            Providing one is useful for sharing an encoder across
            retrievers, custom model configuration, or testing with
            mock encoders.

        Raises:
        ------
        ValueError: If faiss_index or chunk_store is None.
        """
        # Fail fast on missing dependencies rather than at first use.
        if faiss_index is None:
            msg = "faiss_index cannot be None"
            raise ValueError(msg)
        if chunk_store is None:
            msg = "chunk_store cannot be None"
            raise ValueError(msg)

        self._faiss_index: FAISSIndex = faiss_index
        self._chunk_store: ChunkStore = chunk_store
        # None acts as the sentinel for "lazy load on first retrieve()".
        self._encoder: BGEEncoder | None = encoder

        logger.debug(
            "Initialized DenseRetriever with %d indexed vectors",
            self._faiss_index.num_vectors,
        )

    # -------------------------------------------------------------------------
    # Private Methods
    # -------------------------------------------------------------------------

    def _ensure_encoder_loaded(self) -> BGEEncoder:
        """Return the query encoder, creating and caching it on first use.

        If an encoder was supplied to the constructor it is returned
        directly. Otherwise a new BGEEncoder is created once and cached,
        so subsequent calls reuse the same instance without reloading.

        Returns:
        -------
        The BGEEncoder instance to use for query encoding.
        """
        if self._encoder is not None:
            return self._encoder

        # Import here (not at module level) so torch and
        # sentence-transformers are only loaded when actually needed.
        logger.debug("Lazy loading BGEEncoder for query encoding")
        from rag_chatbot.embeddings import BGEEncoder

        self._encoder = BGEEncoder()
        return self._encoder

    def _normalize_score(self, raw_score: float) -> float:
        """Normalize a FAISS inner product score to the [0, 1] range.

        For normalized embeddings (like those from BGE models), FAISS
        IndexFlatIP inner product scores equal the cosine similarity and
        lie in [-1, 1]. The linear map (x + 1) / 2 sends:
        - raw_score =  1.0 (perfect match) -> 1.0
        - raw_score =  0.0 (orthogonal)    -> 0.5
        - raw_score = -1.0 (opposite)      -> 0.0

        Args:
        ----
        raw_score: Raw inner product score from FAISS search.
            Expected range is [-1, 1] for normalized embeddings.

        Returns:
        -------
        Normalized score clamped to [0, 1]; clamping guards against
        scores slightly outside [-1, 1] due to numerical precision.

        Example:
        -------
        >>> retriever._normalize_score(-0.5)
        0.25
        """
        normalized = (raw_score + 1.0) / 2.0
        # Clamp to handle numerical precision edge cases.
        return max(0.0, min(1.0, normalized))

    # -------------------------------------------------------------------------
    # Public Methods (Retriever Protocol)
    # -------------------------------------------------------------------------

    def retrieve(self, query: str, top_k: int = 6) -> list[RetrievalResult]:
        """Retrieve relevant chunks for a given query.

        Encodes the query using BGE embeddings and searches the FAISS
        index for the most similar chunks. Results include full chunk
        metadata from the chunk store, with scores normalized to [0, 1].

        Args:
        ----
        query: The search query string. Normalized before encoding to
            handle common text issues (extra spaces, OCR artifacts).
        top_k: Maximum number of results to return. Defaults to 6.
            Must be a positive integer. If the index contains fewer
            vectors, all vectors are returned.

        Returns:
        -------
        RetrievalResult objects sorted by score in descending order.
        Empty list if the index is empty or the query normalizes to an
        empty string.

        Raises:
        ------
        ValueError: If query is empty or top_k is not a positive int.

        Note:
        ----
        Chunks present in the FAISS index but missing from the chunk
        store are skipped with a warning log (index/store drift).
        """
        # --- Validate input parameters -----------------------------------
        if not query:
            msg = "query cannot be empty"
            raise ValueError(msg)
        if not isinstance(top_k, int) or top_k <= 0:
            msg = f"top_k must be a positive integer, got {top_k}"
            raise ValueError(msg)

        # --- Nothing to search against -----------------------------------
        if self._faiss_index.num_vectors == 0:
            logger.debug("FAISS index is empty, returning no results")
            return []

        # --- Normalize query text ----------------------------------------
        # The encoder normalizes internally too; we pre-normalize so the
        # logged query matches what is actually encoded.
        normalized_query = normalize_text(query)
        if not normalized_query:
            logger.warning("Query is empty after normalization: %r", query)
            return []

        logger.debug(
            "Processing query: %r (normalized: %r)",
            query,
            normalized_query,
        )

        # --- Encode query (lazy loads encoder on first call) -------------
        # encode() returns shape (1, embedding_dim); take the single row.
        encoder = self._ensure_encoder_loaded()
        query_embedding = encoder.encode([normalized_query])[0]
        logger.debug(
            "Encoded query to embedding with shape %s",
            query_embedding.shape,
        )

        # --- Search FAISS index ------------------------------------------
        # Returns (chunk_id, score) tuples sorted by score descending.
        search_results = self._faiss_index.search(
            query_embedding=query_embedding,
            top_k=top_k,
        )
        logger.debug(
            "FAISS search returned %d results",
            len(search_results),
        )

        # --- Join hits with chunk metadata and build results -------------
        results: list[RetrievalResult] = []
        for chunk_id, raw_score in search_results:
            chunk = self._chunk_store.get(chunk_id)
            if chunk is None:
                # Index and chunk store can drift out of sync; warn and
                # continue with the remaining hits.
                logger.warning(
                    "Chunk %r found in FAISS index but not in chunk store, skipping",
                    chunk_id,
                )
                continue

            normalized_score = self._normalize_score(raw_score)

            # RetrievalResult validates all fields (including the score
            # range); skip individual invalid results rather than failing
            # the whole query.
            try:
                result = RetrievalResult(
                    chunk_id=chunk_id,
                    text=chunk.text,
                    score=normalized_score,
                    heading_path=chunk.heading_path,
                    source=chunk.source,
                    page=chunk.page,
                )
                results.append(result)
            except Exception as e:
                logger.warning(
                    "Failed to create RetrievalResult for chunk %r: %s",
                    chunk_id,
                    str(e),
                )
                continue

        # Results are already sorted by score from FAISS search.
        # Truncate long queries so log lines stay readable.
        max_query_log_length = 50
        if len(normalized_query) > max_query_log_length:
            query_preview = normalized_query[:max_query_log_length] + "..."
        else:
            query_preview = normalized_query
        logger.info(
            "Retrieved %d results for query: %r",
            len(results),
            query_preview,
        )

        return results

    # -------------------------------------------------------------------------
    # Properties
    # -------------------------------------------------------------------------

    @property
    def faiss_index(self) -> FAISSIndex:
        """The FAISS index used for search."""
        return self._faiss_index

    @property
    def chunk_store(self) -> ChunkStore:
        """The chunk store used for metadata lookup."""
        return self._chunk_store

    @property
    def num_indexed(self) -> int:
        """Number of vectors in the FAISS index."""
        return self._faiss_index.num_vectors

    @property
    def encoder_loaded(self) -> bool:
        """True if the encoder is available (provided or lazily loaded)."""
        return self._encoder is not None