# Prepare for HF Space deployment (commit d01a7e3)
"""Dense retrieval using FAISS vector similarity search.
This module provides the DenseRetriever class for semantic retrieval
using dense embeddings. The retriever combines:
- FAISS index for efficient nearest neighbor search
- BGE encoder for query embedding
- ChunkStore for metadata lookup
The dense retriever is the primary retrieval component, providing
semantic understanding of queries through embedding similarity.
Design Decisions:
- The encoder is lazy loaded to avoid heavy dependencies (torch,
sentence-transformers) when using prebuilt indexes
- Score normalization maps FAISS inner product scores to [0, 1] range
- Missing chunks are handled gracefully with warnings (index/store mismatch)
Lazy Loading:
The BGEEncoder is loaded on first retrieve() call if not provided
in the constructor. This follows the project convention of lazy-loading
heavy dependencies.
Example:
-------
>>> from pathlib import Path
>>> from rag_chatbot.retrieval import DenseRetriever, FAISSIndex, ChunkStore
>>> # Load index and chunks
>>> index = FAISSIndex.load(Path("data/index.faiss"))
>>> chunks = ChunkStore(Path("data/chunks/chunks.jsonl"))
>>> # Create retriever
>>> retriever = DenseRetriever(faiss_index=index, chunk_store=chunks)
>>> # Retrieve
>>> results = retriever.retrieve("What is PMV?", top_k=5)
>>> for result in results:
... print(f"{result.chunk_id}: {result.score:.3f}")
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from rag_chatbot.embeddings.encoder import BGEEncoder
from .chunk_store import ChunkStore
from .faiss_index import FAISSIndex
# Import the models directly (lightweight Pydantic models)
from .models import RetrievalResult, normalize_text
# =============================================================================
# Module Exports
# =============================================================================
__all__: list[str] = ["DenseRetriever"]
# =============================================================================
# Logger
# =============================================================================
logger = logging.getLogger(__name__)
class DenseRetriever:
    """Dense retriever using FAISS for semantic similarity search.

    Implements the Retriever protocol (rag_chatbot.retrieval.models) by
    combining:

    - FAISSIndex for efficient nearest-neighbor search
    - ChunkStore for joining search hits with chunk metadata
    - BGEEncoder for encoding queries into embedding vectors

    The retriever uses inner-product similarity (FAISS IndexFlatIP),
    which works well with normalized embeddings such as those produced
    by BGE models.

    Score Normalization:
        FAISS inner-product scores for normalized vectors lie in [-1, 1].
        They are mapped to [0, 1] via ``(score + 1) / 2``:
        - Perfect similarity (1.0)  -> 1.0
        - Orthogonal (0.0)          -> 0.5
        - Opposite (-1.0)           -> 0.0

    Lazy Loading:
        The BGEEncoder is created on the first retrieve() call if not
        supplied in the constructor, deferring the torch /
        sentence-transformers imports until actually needed.

    Attributes:
    ----------
    faiss_index : FAISSIndex
        The FAISS index used for similarity search.
    chunk_store : ChunkStore
        Store for looking up chunk metadata by ID.

    Example:
    -------
    >>> # With lazy-loaded encoder
    >>> retriever = DenseRetriever(faiss_index=index, chunk_store=chunks)
    >>> results = retriever.retrieve("thermal comfort calculation")
    >>> # With pre-loaded encoder
    >>> encoder = BGEEncoder()
    >>> retriever = DenseRetriever(
    ...     faiss_index=index,
    ...     chunk_store=chunks,
    ...     encoder=encoder
    ... )

    Note:
    ----
    The retriever implements the Retriever protocol defined in
    rag_chatbot.retrieval.models, enabling use with HybridRetriever
    and dependency injection in tests.
    """

    def __init__(
        self,
        faiss_index: FAISSIndex,
        chunk_store: ChunkStore,
        encoder: BGEEncoder | None = None,
    ) -> None:
        """Initialize the dense retriever.

        Args:
        ----
        faiss_index: FAISS index for vector similarity search. Must be a
            trained/loaded FAISSIndex instance built with embeddings from
            the same model as the encoder (default: bge-small-en-v1.5).
        chunk_store: Store for chunk metadata lookup. Must contain chunks
            matching the indexed chunk_ids.
        encoder: Optional BGE encoder for query embedding. If None, a
            BGEEncoder is created lazily on first retrieve() call.
            Providing one is useful for sharing an encoder across
            retrievers, using a custom model/configuration, or injecting
            mocks in tests.

        Raises:
        ------
        ValueError: If faiss_index or chunk_store is None.

        Example:
        -------
        >>> retriever = DenseRetriever(
        ...     faiss_index=index,
        ...     chunk_store=chunks
        ... )
        """
        # Fail fast on missing required dependencies.
        if faiss_index is None:
            msg = "faiss_index cannot be None"
            raise ValueError(msg)
        if chunk_store is None:
            msg = "chunk_store cannot be None"
            raise ValueError(msg)

        self._faiss_index: FAISSIndex = faiss_index
        self._chunk_store: ChunkStore = chunk_store

        # None acts as the sentinel for "not yet lazy-loaded".
        self._encoder: BGEEncoder | None = encoder

        logger.debug(
            "Initialized DenseRetriever with %d indexed vectors",
            self._faiss_index.num_vectors,
        )

    # -------------------------------------------------------------------------
    # Private Methods
    # -------------------------------------------------------------------------
    def _ensure_encoder_loaded(self) -> BGEEncoder:
        """Return the encoder, creating and caching it on first use.

        If an encoder was provided in the constructor it is returned
        directly; otherwise a new BGEEncoder is created once and cached,
        so subsequent calls reuse the same instance.

        Returns:
        -------
        The BGEEncoder instance to use for query encoding.
        """
        if self._encoder is not None:
            return self._encoder

        # Import here (not at module level) so torch and
        # sentence-transformers are only loaded when encoding is needed.
        logger.debug("Lazy loading BGEEncoder for query encoding")
        from rag_chatbot.embeddings import BGEEncoder

        self._encoder = BGEEncoder()
        return self._encoder

    def _normalize_score(self, raw_score: float) -> float:
        """Normalize a FAISS inner product score to [0, 1] range.

        For normalized embeddings (like those from BGE models), FAISS
        IndexFlatIP scores are cosine similarities in [-1, 1]. This maps
        them to [0, 1] for consistency with other retrievers and the
        RetrievalResult score constraints.

        Mapping (``(raw_score + 1) / 2``):
        - raw_score =  1.0 (perfect match) -> 1.0
        - raw_score =  0.0 (orthogonal)    -> 0.5
        - raw_score = -1.0 (opposite)      -> 0.0

        Args:
        ----
        raw_score: Raw inner product score from FAISS search.
            Expected range is [-1, 1] for normalized embeddings.

        Returns:
        -------
        Normalized score in [0, 1], clamped to guard against numerical
        precision pushing scores slightly outside [-1, 1].

        Example:
        -------
        >>> retriever._normalize_score(0.85)
        0.925
        >>> retriever._normalize_score(-0.5)
        0.25
        """
        normalized = (raw_score + 1.0) / 2.0
        # Clamp to [0, 1] to absorb floating-point edge cases.
        return max(0.0, min(1.0, normalized))

    # -------------------------------------------------------------------------
    # Public Methods (Retriever Protocol)
    # -------------------------------------------------------------------------
    def retrieve(self, query: str, top_k: int = 6) -> list[RetrievalResult]:
        """Retrieve relevant chunks for a given query.

        Encodes the query with BGE embeddings, searches the FAISS index
        for the most similar chunks, joins hits with chunk metadata from
        the store, and returns results with scores normalized to [0, 1].

        Args:
        ----
        query: The search query string. Normalized before encoding to
            handle common text issues (extra spaces, OCR artifacts).
        top_k: Maximum number of results to return. Defaults to 6.
            Must be a positive integer (bool is rejected). If the index
            contains fewer vectors, all vectors are returned.

        Returns:
        -------
        List of RetrievalResult objects sorted by score descending.
        Each result carries chunk_id, text, normalized score,
        heading_path, source, and page.

        Returns an empty list if the index is empty or the query is
        empty after normalization.

        Raises:
        ------
        ValueError: If query is empty or top_k is not a positive integer.
        RuntimeError: If FAISS search fails.

        Example:
        -------
        >>> results = retriever.retrieve("What is PMV?", top_k=5)
        >>> for result in results:
        ...     print(f"[{result.score:.3f}] {result.chunk_id}")

        Note:
        ----
        Chunks present in the FAISS index but missing from the chunk
        store are skipped with a warning (index/store out of sync).
        """
        # Step 1: Validate inputs. bool must be rejected explicitly:
        # it is an int subclass, so top_k=True would otherwise silently
        # behave as top_k=1 instead of failing the type check.
        if not query:
            msg = "query cannot be empty"
            raise ValueError(msg)
        if isinstance(top_k, bool) or not isinstance(top_k, int) or top_k <= 0:
            msg = f"top_k must be a positive integer, got {top_k}"
            raise ValueError(msg)

        # Step 2: Nothing to search in an empty index.
        if self._faiss_index.num_vectors == 0:
            logger.debug("FAISS index is empty, returning no results")
            return []

        # Step 3: Normalize query text (the encoder also normalizes
        # internally; we pre-normalize for logging and the empty check).
        normalized_query = normalize_text(query)
        if not normalized_query:
            logger.warning("Query is empty after normalization: %r", query)
            return []

        logger.debug(
            "Processing query: %r (normalized: %r)",
            query,
            normalized_query,
        )

        # Step 4: Encode the query (lazy-loads encoder if needed).
        # encode() takes a list; we take the single resulting embedding.
        encoder = self._ensure_encoder_loaded()
        query_embedding = encoder.encode([normalized_query])[0]
        logger.debug(
            "Encoded query to embedding with shape %s",
            query_embedding.shape,
        )

        # Step 5: FAISS search returns (chunk_id, score) tuples sorted
        # by score descending.
        search_results = self._faiss_index.search(
            query_embedding=query_embedding,
            top_k=top_k,
        )
        logger.debug(
            "FAISS search returned %d results",
            len(search_results),
        )

        # Steps 6-8: Join with chunk metadata, normalize scores, and
        # build validated RetrievalResult objects.
        results: list[RetrievalResult] = []
        for chunk_id, raw_score in search_results:
            chunk = self._chunk_store.get(chunk_id)
            if chunk is None:
                # Index and store out of sync; skip but keep going.
                logger.warning(
                    "Chunk %r found in FAISS index but not in chunk store, skipping",
                    chunk_id,
                )
                continue

            normalized_score = self._normalize_score(raw_score)

            # RetrievalResult validates all fields (incl. score range);
            # a bad chunk should not abort the whole result set.
            try:
                result = RetrievalResult(
                    chunk_id=chunk_id,
                    text=chunk.text,
                    score=normalized_score,
                    heading_path=chunk.heading_path,
                    source=chunk.source,
                    page=chunk.page,
                )
                results.append(result)
            except Exception as e:
                logger.warning(
                    "Failed to create RetrievalResult for chunk %r: %s",
                    chunk_id,
                    str(e),
                )
                continue

        # Results are already sorted by score from the FAISS search.
        # Truncate the query for logging to keep log lines short.
        max_query_log_length = 50
        if len(normalized_query) > max_query_log_length:
            query_preview = normalized_query[:max_query_log_length] + "..."
        else:
            query_preview = normalized_query

        logger.info(
            "Retrieved %d results for query: %r",
            len(results),
            query_preview,
        )
        return results

    # -------------------------------------------------------------------------
    # Properties
    # -------------------------------------------------------------------------
    @property
    def faiss_index(self) -> FAISSIndex:
        """Get the FAISS index used for search.

        Returns
        -------
        The FAISSIndex instance.
        """
        return self._faiss_index

    @property
    def chunk_store(self) -> ChunkStore:
        """Get the chunk store used for metadata lookup.

        Returns
        -------
        The ChunkStore instance.
        """
        return self._chunk_store

    @property
    def num_indexed(self) -> int:
        """Get the number of vectors in the FAISS index.

        Returns
        -------
        Number of indexed vectors.
        """
        return self._faiss_index.num_vectors

    @property
    def encoder_loaded(self) -> bool:
        """Check whether the encoder has been loaded.

        True if an encoder is available (provided in the constructor or
        already lazy-loaded); False if lazy initialization is pending.

        Returns
        -------
        True if encoder is loaded, False otherwise.
        """
        return self._encoder is not None