| """BM25 sparse retrieval for keyword-based search. | |
| This module provides the BM25Retriever class for sparse retrieval using the | |
| BM25 (Best Match 25) algorithm. BM25 is a probabilistic ranking function that | |
| scores documents based on term frequency (TF) and inverse document frequency (IDF). | |
| BM25 is particularly effective for: | |
| - Keyword matching: Exact term retrieval where semantic similarity may fail | |
| - Out-of-vocabulary terms: Technical terms or acronyms not in embedding vocab | |
| - Hybrid retrieval: Complementing dense embeddings with sparse signals | |
| The BM25 scoring formula is: | |
| score(D,Q) = sum_{i=1}^{n} IDF(q_i) * (f(q_i,D) * (k1+1)) / | |
| (f(q_i,D) + k1 * (1-b+b*|D|/avgdl)) | |
| Where: | |
| - f(q_i,D) = term frequency of query term q_i in document D | |
| - |D| = length of document D in words | |
| - avgdl = average document length across the corpus | |
| - k1 = term frequency saturation parameter (default: 1.5) | |
| - b = document length normalization parameter (default: 0.75) | |
| - IDF(q_i) = log((N - n(q_i) + 0.5) / (n(q_i) + 0.5)) | |
| - N = total number of documents | |
| - n(q_i) = number of documents containing term q_i | |
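
A tiny worked example (illustrative numbers, not from a real corpus): with
N = 100 documents of which n(q_i) = 10 contain the term,
IDF(q_i) = log(90.5 / 10.5) ≈ 2.154 (natural log). For f(q_i,D) = 2,
|D| = 100, avgdl = 120, k1 = 1.5, and b = 0.75, that term contributes
2.154 * (2 * 2.5) / (2 + 1.5 * (0.25 + 0.75 * 100/120))
= 2.154 * 5 / 3.3125 ≈ 3.25 to score(D,Q).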

Design Decisions:

- Lazy loading: rank_bm25 is imported on first use to avoid overhead
- Text normalization: Uses normalize_text from models.py plus tokenization
- Score normalization: Raw BM25 scores normalized to [0, 1] using min-max
- Persistence: Index saved via pickle with tokenized corpus (not BM25 object)

Lazy Loading:
    The rank_bm25 library is loaded on first use (build or load) to avoid
    import overhead when BM25 retrieval is not needed. This follows the
    project convention for heavy dependencies.

Example:
-------
>>> from rag_chatbot.retrieval import BM25Retriever
>>> # Build index from corpus
>>> retriever = BM25Retriever(k1=1.5, b=0.75)
>>> retriever.build(corpus=["doc1 text", "doc2 text"], chunk_ids=["c1", "c2"])
>>> # Retrieve
>>> results = retriever.retrieve("search query", top_k=5)
>>> for chunk_id, score in results:
...     print(f"{chunk_id}: {score:.3f}")

"""
from __future__ import annotations

import logging
import pickle
import re
import string
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Any

# Import the normalize_text function from models (lightweight, no heavy deps)
from .models import normalize_text
# =============================================================================
# Logger
# =============================================================================
logger = logging.getLogger(__name__)

# =============================================================================
# Type Checking Imports
# =============================================================================
# These imports are only processed by type checkers (mypy, pyright) and IDEs.
# They enable proper type hints without runtime overhead.
# =============================================================================
if TYPE_CHECKING:
    from rank_bm25 import BM25Okapi

# =============================================================================
# Module Exports
# =============================================================================
__all__: list[str] = ["BM25Retriever"]

# =============================================================================
# Lazy Loading for Heavy Dependencies
# =============================================================================
# The rank_bm25 library is loaded lazily on first use. This pattern ensures:
# - Fast import times when BM25 is not needed
# - Minimal memory usage until retrieval starts
# - Compatibility with environments without rank_bm25 installed
# =============================================================================

# Global variable to cache the lazily-loaded rank_bm25 module.
# Using None as sentinel value to indicate "not yet loaded".
_bm25_module: ModuleType | None = None


def _get_bm25_module() -> ModuleType:
    """Lazily import and cache the rank_bm25 module.

    This function implements lazy loading for the rank_bm25 dependency.
    On first call, it imports the module and caches it globally. Subsequent
    calls return the cached module without re-importing.

    The lazy loading pattern ensures that the heavy dependency is only loaded
    when BM25 functionality is actually needed, improving startup time for
    applications that may not use BM25 retrieval.

    Returns:
    -------
    The rank_bm25 module, cached for subsequent calls.

    Raises:
    ------
    ImportError: If rank_bm25 is not installed. Install with:
        pip install rank-bm25
        or
        poetry add rank-bm25

    Example:
    -------
    >>> bm25 = _get_bm25_module()
    >>> index = bm25.BM25Okapi(tokenized_corpus)

    """
    global _bm25_module  # noqa: PLW0603

    # Return cached module if already loaded
    if _bm25_module is not None:
        return _bm25_module

    # Import and cache the module on first use.
    # This may take a moment as rank_bm25 loads numpy dependencies.
    try:
        import rank_bm25 as bm25
    except ImportError as exc:
        msg = (
            "rank_bm25 is required for BM25 retrieval. "
            "Install with: pip install rank-bm25 (or: poetry add rank-bm25)"
        )
        raise ImportError(msg) from exc

    _bm25_module = bm25
    return _bm25_module


# =============================================================================
# Text Processing Utilities
# =============================================================================
# These functions handle text normalization and tokenization for BM25 indexing.
# Proper text processing is critical for effective keyword matching.
# =============================================================================

# Pre-compiled regex pattern for punctuation removal.
# Compiling once at module load avoids rebuilding the pattern on every call.
# Matches any punctuation character from string.punctuation.
_PUNCTUATION_PATTERN: re.Pattern[str] = re.compile(f"[{re.escape(string.punctuation)}]")


def _tokenize(text: str) -> list[str]:
    """Tokenize text for BM25 indexing.

    This function performs the following text processing steps:

    1. Normalize text using normalize_text (fix whitespace, capitalization)
    2. Convert to lowercase for case-insensitive matching
    3. Remove punctuation (commas, periods, etc.)
    4. Split on whitespace into tokens
    5. Filter out empty tokens

    The tokenization strategy is intentionally simple (whitespace splitting)
    because BM25 works well with basic tokenization, and more sophisticated
    tokenization (stemming, lemmatization) can sometimes hurt retrieval
    performance for technical documentation.

    Args:
    ----
    text: The text string to tokenize.
        Can contain any UTF-8 characters including Unicode.

    Returns:
    -------
    List of lowercase tokens with punctuation removed.
    Empty list if text is empty or whitespace-only.

    Example:
    -------
    >>> _tokenize("Hello, World!")
    ['hello', 'world']
    >>> _tokenize("The PMV model is 25.5 degrees.")
    ['the', 'pmv', 'model', 'is', '255', 'degrees']
    >>> _tokenize(" ")
    []

    Note:
    ----
    - Numbers are preserved (not removed) to support queries like "ISO 7730"
    - Unicode characters are preserved for international text support
    - Contractions like "don't" become "dont" (apostrophe removed)

    """
    # Step 1: Apply text normalization from models.py.
    # This fixes extra whitespace, capitalization after periods, etc.
    normalized = normalize_text(text)

    # Handle empty text after normalization
    if not normalized:
        return []

    # Step 2: Convert to lowercase for case-insensitive matching.
    # This ensures "PMV" and "pmv" are treated as the same term.
    lowercased = normalized.lower()

    # Step 3: Remove punctuation using the pre-compiled regex.
    # This converts "Hello, world!" to "Hello world".
    # Punctuation can interfere with term matching.
    without_punctuation = _PUNCTUATION_PATTERN.sub("", lowercased)

    # Step 4: Split on whitespace.
    # split() without arguments splits on any whitespace run and drops
    # empty strings.
    tokens = without_punctuation.split()

    # Step 5: Filter out any remaining empty tokens (defensive).
    # The split() above should handle this, but being explicit is safer.
    return [token for token in tokens if token]


# =============================================================================
# Score Normalization Utilities
# =============================================================================
# BM25 raw scores are unbounded positive values. We normalize to [0, 1]
# for consistency with dense retrieval and the RetrievalResult model.
# =============================================================================


def _normalize_scores(scores: list[float]) -> list[float]:
    """Normalize BM25 scores to [0, 1] range using min-max normalization.

    BM25 raw scores are non-negative values that can be arbitrarily large
    depending on term frequency, document length, and corpus statistics. This
    function normalizes them to the [0, 1] range for consistency with other
    retrievers.

    Normalization Formula:
        normalized = (score - min_score) / (max_score - min_score)

    This maps:

    - Minimum score -> 0.0
    - Maximum score -> 1.0
    - All other scores -> proportionally between 0 and 1

    Edge Cases:

    - Empty list: Returns empty list
    - Single value: Returns [1.0] (or [0.0] if that value is 0)
    - All same non-zero values: Returns all 1.0 (all equally relevant)
    - All zeros: Returns all 0.0 (no relevance detected)

    Args:
    ----
    scores: List of raw BM25 scores (non-negative floats).

    Returns:
    -------
    List of normalized scores in [0.0, 1.0] range.

    Example:
    -------
    >>> _normalize_scores([0.0, 0.5, 1.0])
    [0.0, 0.5, 1.0]
    >>> _normalize_scores([2.0, 4.0, 6.0])
    [0.0, 0.5, 1.0]
    >>> _normalize_scores([5.0, 5.0, 5.0])
    [1.0, 1.0, 1.0]
    >>> _normalize_scores([])
    []

    """
    # Handle empty list
    if not scores:
        return []

    # Find min and max scores for normalization
    min_score = min(scores)
    max_score = max(scores)

    # Calculate the range for normalization
    score_range = max_score - min_score

    # =================================================================
    # Edge case: All scores are the same (range is 0)
    # =================================================================
    # When all documents have the same score, they're equally relevant.
    # We return 1.0 for all to indicate "best available match".
    # =================================================================
    if score_range == 0:
        # If max_score is also 0, no relevance was detected.
        # Return 0.0 for all in this case.
        if max_score == 0:
            return [0.0] * len(scores)
        # Otherwise, all scores are equal and non-zero.
        # Return 1.0 for all (equally "best").
        return [1.0] * len(scores)

    # Normal case: Apply min-max normalization.
    # Maps min -> 0.0 and max -> 1.0.
    return [(score - min_score) / score_range for score in scores]


# =============================================================================
# BM25Retriever Class
# =============================================================================


class BM25Retriever:
    """BM25-based sparse retriever for keyword search.

    This class implements BM25 (Best Match 25) retrieval, a probabilistic
    ranking function widely used for information retrieval. BM25 complements
    dense embeddings by handling exact keyword matches and out-of-vocabulary
    terms that may not be well represented in embedding space.

    The BM25Okapi variant (from the rank_bm25 library) is used, which
    implements standard BM25 scoring with Okapi weighting. Key parameters:

    Parameters
    ----------
    k1 : float
        Term frequency saturation parameter. Controls how quickly term
        frequency reaches saturation. Higher values give more weight to
        repeated terms. Typical range: 1.2 to 2.0.
        - k1 = 0: Binary term presence (TF ignored)
        - k1 = 1.5 (default): Standard BM25 setting
        - k1 = 3+: Very high TF weight
    b : float
        Document length normalization parameter. Controls how much
        document length affects scoring. Range: 0.0 to 1.0.
        - b = 0: No length normalization (long docs not penalized)
        - b = 0.75 (default): Standard BM25 setting
        - b = 1: Full length normalization

    Lazy Loading:
        The rank_bm25 library is loaded on first use (build or load) to
        avoid import overhead when BM25 is not needed.

    Thread Safety:
        This class is NOT thread-safe. For concurrent access, use separate
        instances or external synchronization, as sketched below.
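
    A minimal synchronization sketch (hypothetical helper; one shared lock
    guarding a shared instance):

    >>> import threading
    >>> _lock = threading.Lock()
    >>> def safe_retrieve(r, q):
    ...     with _lock:
    ...         return r.retrieve(q, top_k=5)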

    Attributes
    ----------
    _k1 : float
        BM25 k1 parameter (term frequency saturation).
    _b : float
        BM25 b parameter (document length normalization).
    _bm25 : BM25Okapi | None
        The BM25 index (None until build() is called).
    _tokenized_corpus : list[list[str]] | None
        Tokenized documents (None until build() is called).
    _chunk_ids : list[str] | None
        Chunk identifiers mapping indices to IDs.

    Example
    -------
    >>> retriever = BM25Retriever(k1=1.5, b=0.75)
    >>> retriever.build(
    ...     corpus=["The PMV model predicts thermal sensation."],
    ...     chunk_ids=["chunk_001"]
    ... )
    >>> results = retriever.retrieve("PMV model", top_k=5)
    >>> chunk_id, score = results[0]
    >>> print(f"Best match: {chunk_id} with score {score:.3f}")

    See Also
    --------
    - https://en.wikipedia.org/wiki/Okapi_BM25
    - https://github.com/dorianbrown/rank_bm25

    """

    def __init__(
        self,
        k1: float = 1.5,
        b: float = 0.75,
    ) -> None:
        """Initialize the BM25 retriever with configurable parameters.

        Creates a new BM25Retriever instance with the specified BM25
        parameters. The index is NOT built during initialization - call
        build() to create the index, or load() to restore a saved index.

        This follows the lazy loading pattern: no heavy dependencies are
        loaded during __init__. The rank_bm25 library is only imported when
        build() or load() is called.

        Args:
        ----
        k1: Term frequency saturation parameter. Higher values give more
            weight to term frequency. Must be non-negative.
            Defaults to 1.5 (standard BM25 setting).
        b: Document length normalization parameter. 0 means no
            normalization, 1 means full normalization. Should be
            in [0, 1] range. Defaults to 0.75 (standard BM25 setting).

        Example:
        -------
        >>> # Default parameters
        >>> retriever = BM25Retriever()
        >>> # Custom parameters for short documents
        >>> retriever = BM25Retriever(k1=1.2, b=0.5)
        >>> # High term frequency weight
        >>> retriever = BM25Retriever(k1=2.5, b=0.75)

        Note:
        ----
        - The retriever is not usable until build() or load() is called
        - No validation is performed on k1/b ranges; values are passed
          through to rank_bm25 as-is

        """
        # =================================================================
        # Store BM25 parameters
        # =================================================================
        # These parameters are used when building the BM25Okapi index:
        # k1: Controls term frequency saturation (default 1.5)
        # b: Controls document length normalization (default 0.75)
        # =================================================================
        self._k1: float = k1
        self._b: float = b

        # =================================================================
        # Initialize state as None (not yet built)
        # =================================================================
        # The BM25 index and related data structures are created in build()
        # or restored in load(). Until then, these are None.
        # =================================================================

        # The BM25Okapi index from rank_bm25.
        # This is the core data structure for BM25 scoring.
        self._bm25: BM25Okapi | None = None

        # Tokenized version of the corpus.
        # Stored for persistence (the BM25Okapi object is not pickled directly).
        self._tokenized_corpus: list[list[str]] | None = None

        # Mapping from corpus indices to chunk IDs.
        # Used to return chunk_ids in retrieve() results.
        self._chunk_ids: list[str] | None = None

    # =========================================================================
    # Private Helper Methods
    # =========================================================================

    def _is_built(self) -> bool:
        """Check if the BM25 index has been built.

        This helper method checks whether the retriever has been initialized
        with a corpus (via build() or load()). Used for validation before
        operations that require a built index.

        Returns
        -------
        bool
            True if the index is built and ready for retrieval.
            False if build() or load() has not been called yet.

        """
        return (
            self._bm25 is not None
            and self._chunk_ids is not None
            and self._tokenized_corpus is not None
        )

    def chunk_ids(self) -> list[str]:
        """Get the list of chunk IDs in index order.

        Returns the chunk IDs that were used to build this index. The order
        matches the order in which documents were indexed, which is important
        for coordinating with other indexes (e.g., FAISS) that use the same
        ordering.

        Returns
        -------
        List of chunk ID strings in index order.

        Raises
        ------
        RuntimeError: If the index has not been built yet.

        """
        if self._chunk_ids is None:
            msg = "Index not built - call build() or load() first"
            raise RuntimeError(msg)
        return self._chunk_ids

    # =========================================================================
    # Public Methods
    # =========================================================================

    def build(
        self,
        corpus: list[str],
        chunk_ids: list[str],
    ) -> None:
        """Build the BM25 index from a corpus of documents.

        This method creates the BM25 index by:

        1. Validating input parameters
        2. Tokenizing each document in the corpus
        3. Building the BM25Okapi index with the tokenized corpus
        4. Storing chunk_ids for mapping indices to identifiers

        The build process is idempotent - calling build() multiple times
        replaces the previous index with a new one.

        Args:
        ----
        corpus: List of document texts to index.
            Each string is a document that will be tokenized and indexed.
            Documents are normalized (whitespace, case) during tokenization.
        chunk_ids: List of unique chunk identifiers.
            Must have the same length as corpus.
            Used to identify documents in retrieve() results.

        Raises:
        ------
        ValueError: If corpus is empty.
        ValueError: If corpus and chunk_ids have different lengths.
        ValueError: If all documents are empty after tokenization.

        Example:
        -------
        >>> retriever = BM25Retriever()
        >>> corpus = [
        ...     "The PMV model predicts thermal sensation.",
        ...     "Thermal comfort depends on air temperature.",
        ... ]
        >>> chunk_ids = ["chunk_001", "chunk_002"]
        >>> retriever.build(corpus, chunk_ids)
        >>> # Index is now ready for retrieval

        Note:
        ----
        - Documents are tokenized (lowercase, punctuation removed, split)
        - Empty documents after tokenization are preserved in the index
          but will not match any queries
        - The rank_bm25 library is loaded on first call to build()

        """
        # =================================================================
        # Step 1: Validate corpus is not empty
        # =================================================================
        if not corpus:
            msg = "Cannot build BM25 index with empty corpus"
            raise ValueError(msg)

        # =================================================================
        # Step 2: Validate lengths match
        # =================================================================
        if len(corpus) != len(chunk_ids):
            msg = (
                f"corpus and chunk_ids length mismatch: "
                f"{len(corpus)} documents but {len(chunk_ids)} chunk_ids"
            )
            raise ValueError(msg)

        # =================================================================
        # Step 3: Tokenize all documents
        # =================================================================
        # Each document is normalized and tokenized for BM25 indexing.
        # The tokenized corpus is stored for persistence.
        # =================================================================
        tokenized_corpus: list[list[str]] = [_tokenize(doc) for doc in corpus]

        # =================================================================
        # Step 4: Validate that at least some documents have tokens
        # =================================================================
        # If ALL documents are empty after tokenization, the index is useless.
        # This catches cases like corpus = [" ", "\t\n", " \t "].
        # =================================================================
        if all(len(tokens) == 0 for tokens in tokenized_corpus):
            msg = (
                "All documents are empty after tokenization. "
                "Cannot build BM25 index with no terms."
            )
            raise ValueError(msg)

        # =================================================================
        # Step 5: Get the BM25 module (lazy load)
        # =================================================================
        # This is the first point where rank_bm25 is actually needed.
        # The module is cached globally after first import.
        # =================================================================
        bm25_module = _get_bm25_module()

        # =================================================================
        # Step 6: Build the BM25Okapi index
        # =================================================================
        # BM25Okapi is initialized with the tokenized corpus.
        # The k1 and b parameters control scoring behavior.
        # =================================================================
        self._bm25 = bm25_module.BM25Okapi(
            corpus=tokenized_corpus,
            k1=self._k1,
            b=self._b,
        )

        # =================================================================
        # Step 7: Store the tokenized corpus and chunk_ids
        # =================================================================
        # These are needed for:
        # - _tokenized_corpus: persistence (save/load)
        # - _chunk_ids: mapping indices to chunk identifiers
        # =================================================================
        self._tokenized_corpus = tokenized_corpus
        self._chunk_ids = chunk_ids

    def retrieve(
        self,
        query: str,
        top_k: int = 10,
    ) -> list[tuple[str, float]]:
        """Retrieve the most relevant documents for a query.

        This method searches the BM25 index for documents matching the query
        and returns the top-k results sorted by relevance score.

        Processing Steps:

        1. Validate that the index has been built
        2. Validate query and top_k parameters
        3. Tokenize the query (same process as document tokenization)
        4. Score all documents using BM25
        5. Select top-k highest scoring documents
        6. Normalize scores to [0, 1] range
        7. Return results as (chunk_id, score) tuples

        Args:
        ----
        query: The search query string.
            Will be tokenized using the same process as documents.
            Must not be empty or whitespace-only.
        top_k: Maximum number of results to return. Defaults to 10.
            Must be a positive integer.
            If top_k exceeds corpus size, all documents are returned.

        Returns:
        -------
        List of (chunk_id, score) tuples sorted by score descending.
        Scores are normalized to [0.0, 1.0] range.
        Returns at most min(top_k, corpus_size) results.

        Raises:
        ------
        RuntimeError: If retrieve() is called before build() or load().
        ValueError: If query is empty or whitespace-only.
        ValueError: If top_k is not a positive integer.

        Example:
        -------
        >>> results = retriever.retrieve("thermal comfort PMV", top_k=5)
        >>> for chunk_id, score in results:
        ...     print(f"{chunk_id}: {score:.3f}")
        chunk_001: 0.923
        chunk_003: 0.756
        chunk_002: 0.534

        Note:
        ----
        - Query tokenization mirrors document tokenization (lowercase,
          no punctuation, whitespace split)
        - If the query contains no matching terms, results will have
          score 0.0
        - Results are always sorted by score descending (best first)

        """
        # =================================================================
        # Step 1: Validate index is built
        # =================================================================
        if not self._is_built():
            msg = "BM25 index not built. Call build() or load() first."
            raise RuntimeError(msg)

        # =================================================================
        # Step 2: Validate top_k parameter
        # =================================================================
        if not isinstance(top_k, int) or top_k <= 0:
            msg = f"top_k must be a positive integer, got {top_k}"
            raise ValueError(msg)

        # =================================================================
        # Step 3: Validate and tokenize query
        # =================================================================
        # Check for empty query before tokenization
        if not query or not query.strip():
            msg = "query cannot be empty or whitespace-only"
            raise ValueError(msg)

        # Tokenize query using the same process as documents
        query_tokens = _tokenize(query)

        # Check for empty query after tokenization.
        # This can happen if the query only contains punctuation.
        if not query_tokens:
            msg = "query is empty after tokenization (no valid terms)"
            raise ValueError(msg)

        # =================================================================
        # Step 4: Get BM25 scores for all documents
        # =================================================================
        # get_scores returns a numpy array with a score for each document.
        # Type narrowing for mypy (we know _bm25 is not None after the
        # _is_built check).
        assert self._bm25 is not None
        assert self._chunk_ids is not None

        # Get raw BM25 scores (numpy array)
        raw_scores = self._bm25.get_scores(query_tokens)

        # =================================================================
        # Step 5: Create (index, score) pairs and sort by score descending
        # =================================================================
        # We need to track indices to map back to chunk_ids.
        # Convert numpy values to floats for processing.
        indexed_scores: list[tuple[int, float]] = [
            (idx, float(score)) for idx, score in enumerate(raw_scores)
        ]

        # Sort by score descending (highest first)
        indexed_scores.sort(key=lambda x: x[1], reverse=True)

        # =================================================================
        # Step 6: Select top-k results
        # =================================================================
        # Slicing never exceeds the corpus size, so at most
        # min(top_k, corpus_size) results are returned.
        top_k_results = indexed_scores[:top_k]

        # =================================================================
        # Step 7: Normalize scores to [0, 1] range
        # =================================================================
        # Extract scores for normalization
        scores_only = [score for _, score in top_k_results]
        normalized_scores = _normalize_scores(scores_only)

        # =================================================================
        # Step 8: Build final results with chunk_ids
        # =================================================================
        # Map indices to chunk_ids and pair with normalized scores
        results: list[tuple[str, float]] = [
            (self._chunk_ids[idx], norm_score)
            for (idx, _), norm_score in zip(
                top_k_results, normalized_scores, strict=True
            )
        ]

        return results

    def save(self, path: Path) -> None:
        """Save the BM25 index to disk for later restoration.

        Persists the BM25 index state using pickle. The saved data includes:

        - k1, b parameters (for rebuilding BM25Okapi)
        - Tokenized corpus (list of token lists)
        - Chunk IDs (for result mapping)

        Note that the BM25Okapi object itself is not pickled directly because
        it may have compatibility issues. Instead, we save the tokenized
        corpus and rebuild the BM25Okapi index on load().

        Parent directories are created if they don't exist.

        Args:
        ----
        path: File path to save the index.
            Should typically have .pkl extension.
            Parent directories will be created if needed.

        Raises:
        ------
        RuntimeError: If save() is called before build().

        Example:
        -------
        >>> retriever = BM25Retriever()
        >>> retriever.build(corpus, chunk_ids)
        >>> retriever.save(Path("indexes/bm25_index.pkl"))

        Note:
        ----
        - The saved file can be restored with BM25Retriever.load()
        - Pickle format is used; ensure trusted data sources only
        - File size depends on corpus size (tokenized text is stored)
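
        Inspecting the payload (a sketch; the dictionary keys below match
        this module's save format, and the file path is hypothetical):

        >>> import pickle
        >>> with Path("indexes/bm25_index.pkl").open("rb") as f:
        ...     data = pickle.load(f)
        >>> sorted(data.keys())
        ['b', 'chunk_ids', 'k1', 'tokenized_corpus']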
| """ | |
| # ================================================================= | |
| # Step 1: Validate index is built | |
| # ================================================================= | |
| if not self._is_built(): | |
| msg = "Cannot save unbuilt BM25 index. Call build() first." | |
| raise RuntimeError(msg) | |
| # ================================================================= | |
| # Step 2: Create parent directories if needed | |
| # ================================================================= | |
| # This ensures save() works even for nested paths that don't exist | |
| # ================================================================= | |
| path.parent.mkdir(parents=True, exist_ok=True) | |
| # ================================================================= | |
| # Step 3: Prepare data for persistence | |
| # ================================================================= | |
| # We save all the data needed to rebuild the BM25 index: | |
| # - k1, b: BM25 parameters | |
| # - tokenized_corpus: Pre-tokenized documents | |
| # - chunk_ids: Document identifiers | |
| # | |
| # The BM25Okapi object is NOT saved directly because: | |
| # - It may have numpy arrays that complicate pickling | |
| # - Rebuilding from tokenized_corpus is straightforward | |
| # ================================================================= | |
| save_data: dict[str, Any] = { | |
| "k1": self._k1, | |
| "b": self._b, | |
| "tokenized_corpus": self._tokenized_corpus, | |
| "chunk_ids": self._chunk_ids, | |
| } | |
| # ================================================================= | |
| # Step 4: Write to disk using pickle | |
| # ================================================================= | |
| with path.open("wb") as f: | |
| pickle.dump(save_data, f, protocol=pickle.HIGHEST_PROTOCOL) | |

    @classmethod
    def load(cls, path: Path) -> BM25Retriever:
        """Load a BM25 index from disk.

        Restores a BM25Retriever from a previously saved index file. The
        BM25Okapi index is rebuilt from the saved tokenized corpus using
        the saved k1 and b parameters.

        Args:
        ----
        path: File path to load the index from.
            Must be a file created by save().

        Returns:
        -------
        A new BM25Retriever instance with the restored index.

        Raises:
        ------
        FileNotFoundError: If the path does not exist.

        Example:
        -------
        >>> retriever = BM25Retriever.load(Path("indexes/bm25_index.pkl"))
        >>> results = retriever.retrieve("thermal comfort", top_k=5)

        Note:
        ----
        - The returned retriever is immediately usable for retrieval
        - The rank_bm25 library is loaded during this operation
        - Pickle format is used; only load files from trusted sources

        """
        # =================================================================
        # Step 1: Validate path exists
        # =================================================================
        if not path.exists():
            msg = f"BM25 index file not found: {path}"
            raise FileNotFoundError(msg)

        # Step 2: Load saved data from pickle.
        # Note: Only load from trusted sources as pickle can execute code.
        with path.open("rb") as f:
            save_data: dict[str, Any] = pickle.load(f)

        # =================================================================
        # Step 3: Handle different pickle formats
        # =================================================================
        # There are two possible formats:
        # - Build pipeline format: {"bm25": BM25Okapi, "chunk_ids": list}
        # - Retrieval save format: {"k1", "b", "tokenized_corpus", "chunk_ids"}
        # =================================================================
        chunk_ids: list[str] = save_data["chunk_ids"]

        if "bm25" in save_data:
            # Format from build pipeline (embeddings/indexing.py):
            # the BM25Okapi object is stored directly.
            bm25_index = save_data["bm25"]

            # Extract k1 and b from the loaded BM25 object
            k1 = getattr(bm25_index, "k1", 1.5)
            b = getattr(bm25_index, "b", 0.75)

            # Create retriever with extracted parameters
            retriever = cls(k1=k1, b=b)
            retriever._bm25 = bm25_index

            # BM25Okapi does not reliably retain the raw tokenized corpus,
            # so fall back to an empty list; _is_built() only requires a
            # non-None value here.
            retriever._tokenized_corpus = getattr(bm25_index, "corpus", [])

            logger.debug(
                "Loaded BM25 from build pipeline format: %d chunks",
                len(chunk_ids),
            )
        else:
            # Format from this module's save() method
            k1 = save_data["k1"]
            b = save_data["b"]
            tokenized_corpus: list[list[str]] = save_data["tokenized_corpus"]

            # Create new retriever with saved parameters
            retriever = cls(k1=k1, b=b)

            # Get the BM25 module (lazy load)
            bm25_module = _get_bm25_module()

            # Rebuild BM25Okapi index from the tokenized corpus
            retriever._bm25 = bm25_module.BM25Okapi(
                corpus=tokenized_corpus,
                k1=k1,
                b=b,
            )
            retriever._tokenized_corpus = tokenized_corpus

            logger.debug(
                "Loaded BM25 from retrieval save format: %d chunks",
                len(chunk_ids),
            )

        # =================================================================
        # Step 4: Set chunk_ids
        # =================================================================
        retriever._chunk_ids = chunk_ids

        return retriever
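

# =============================================================================
# Usage Sketch
# =============================================================================
# A minimal, illustrative round trip (build -> retrieve -> save -> load).
# The corpus, chunk IDs, and index path are made up for the demo; rank_bm25
# must be installed, and because of the relative import of .models this is
# meant to be run as a module (python -m <package>.<this_module>), not as a
# standalone script.
# =============================================================================
if __name__ == "__main__":
    demo_corpus = [
        "The PMV model predicts thermal sensation.",
        "Thermal comfort depends on air temperature.",
    ]
    demo_ids = ["chunk_001", "chunk_002"]

    demo_retriever = BM25Retriever(k1=1.5, b=0.75)
    demo_retriever.build(corpus=demo_corpus, chunk_ids=demo_ids)

    # Keyword query: "PMV" should rank chunk_001 first
    for demo_chunk_id, demo_score in demo_retriever.retrieve("PMV model", top_k=2):
        print(f"{demo_chunk_id}: {demo_score:.3f}")

    # Persist and restore (hypothetical path)
    demo_path = Path("indexes/bm25_index.pkl")
    demo_retriever.save(demo_path)
    restored = BM25Retriever.load(demo_path)
    assert restored.chunk_ids() == demo_ids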