# Commit 3742372: Add missing logger import to BM25Retriever
"""BM25 sparse retrieval for keyword-based search.
This module provides the BM25Retriever class for sparse retrieval using the
BM25 (Best Match 25) algorithm. BM25 is a probabilistic ranking function that
scores documents based on term frequency (TF) and inverse document frequency (IDF).
BM25 is particularly effective for:
- Keyword matching: Exact term retrieval where semantic similarity may fail
- Out-of-vocabulary terms: Technical terms or acronyms not in embedding vocab
- Hybrid retrieval: Complementing dense embeddings with sparse signals
The BM25 scoring formula is:
score(D,Q) = sum_{i=1}^{n} IDF(q_i) * (f(q_i,D) * (k1+1)) /
(f(q_i,D) + k1 * (1-b+b*|D|/avgdl))
Where:
- f(q_i,D) = term frequency of query term q_i in document D
- |D| = length of document D in words
- avgdl = average document length across the corpus
- k1 = term frequency saturation parameter (default: 1.5)
- b = document length normalization parameter (default: 0.75)
- IDF(q_i) = log((N - n(q_i) + 0.5) / (n(q_i) + 0.5))
- N = total number of documents
- n(q_i) = number of documents containing term q_i
Design Decisions:
- Lazy loading: rank_bm25 is imported on first use to avoid overhead
- Text normalization: Uses normalize_text from models.py plus tokenization
- Score normalization: Raw BM25 scores normalized to [0, 1] using min-max
- Persistence: Index saved via pickle with tokenized corpus (not BM25 object)
Lazy Loading:
The rank_bm25 library is loaded on first use (build or load) to avoid
import overhead when BM25 retrieval is not needed. This follows the
project convention for heavy dependencies.
Example:
-------
>>> from rag_chatbot.retrieval import BM25Retriever
>>> # Build index from corpus
>>> retriever = BM25Retriever(k1=1.5, b=0.75)
>>> retriever.build(corpus=["doc1 text", "doc2 text"], chunk_ids=["c1", "c2"])
>>> # Retrieve
>>> results = retriever.retrieve("search query", top_k=5)
>>> for chunk_id, score in results:
... print(f"{chunk_id}: {score:.3f}")
"""
from __future__ import annotations
import logging
import pickle
import re
import string
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Any
# Import the normalize_text function from models (lightweight, no heavy deps)
from .models import normalize_text
# =============================================================================
# Logger
# =============================================================================
logger = logging.getLogger(__name__)
# =============================================================================
# Type Checking Imports
# =============================================================================
# These imports are only processed by type checkers (mypy, pyright) and IDEs.
# They enable proper type hints without runtime overhead.
# =============================================================================
if TYPE_CHECKING:
from rank_bm25 import BM25Okapi
# =============================================================================
# Module Exports
# =============================================================================
__all__: list[str] = ["BM25Retriever"]
# =============================================================================
# Lazy Loading for Heavy Dependencies
# =============================================================================
# The rank_bm25 library is loaded lazily on first use. This pattern ensures:
# - Fast import times when BM25 is not needed
# - Minimal memory usage until retrieval starts
# - Compatibility with environments without rank_bm25 installed
# =============================================================================
# Global variable to cache the lazily-loaded rank_bm25 module
# Using None as sentinel value to indicate "not yet loaded"
# Module-level cache for the lazily imported rank_bm25 package.
# None is the sentinel for "not imported yet".
_bm25_module: ModuleType | None = None


def _get_bm25_module() -> ModuleType:
    """Return the rank_bm25 module, importing and caching it on first use.

    Deferring the import keeps module load fast for applications that never
    touch BM25 retrieval; after the first call the cached module is returned
    without re-importing.

    Returns:
    -------
        The rank_bm25 module (cached globally for subsequent calls).

    Raises:
    ------
        ImportError: If rank_bm25 is not installed. Install with
            ``pip install rank-bm25`` or ``poetry add rank-bm25``.

    Example:
    -------
        >>> bm25 = _get_bm25_module()
        >>> index = bm25.BM25Okapi(tokenized_corpus)
    """
    global _bm25_module  # noqa: PLW0603
    if _bm25_module is None:
        # First use: pay the import cost once (rank_bm25 pulls in numpy)
        # and remember the module for every later call.
        import rank_bm25

        _bm25_module = rank_bm25
    return _bm25_module
# =============================================================================
# Text Processing Utilities
# =============================================================================
# These functions handle text normalization and tokenization for BM25 indexing.
# Proper text processing is critical for effective keyword matching.
# =============================================================================
# Pre-compile regex pattern for punctuation removal
# This is more efficient than using str.translate for each call
# Matches any punctuation character from string.punctuation
# Compiled once at import time: matches any single character from
# string.punctuation, used to strip punctuation before splitting.
_PUNCTUATION_PATTERN: re.Pattern[str] = re.compile(f"[{re.escape(string.punctuation)}]")


def _tokenize(text: str) -> list[str]:
    """Normalize *text* and split it into lowercase BM25 tokens.

    Processing pipeline: normalize_text (whitespace/capitalization cleanup),
    lowercase, strip punctuation, split on whitespace. The deliberately
    simple whitespace tokenization works well for BM25; stemming or
    lemmatization can hurt retrieval on technical documentation.

    Args:
    ----
        text: Input string; any UTF-8 content, Unicode included.

    Returns:
    -------
        Lowercase, punctuation-free tokens; empty list for empty or
        whitespace-only input.

    Example:
    -------
        >>> _tokenize("Hello, World!")
        ['hello', 'world']
        >>> _tokenize("The PMV model is 25.5 degrees.")
        ['the', 'pmv', 'model', 'is', '255', 'degrees']
        >>> _tokenize("   ")
        []

    Note:
    ----
        - Numbers survive tokenization, supporting queries like "ISO 7730".
        - Contractions lose their apostrophe ("don't" -> "dont").
    """
    cleaned = normalize_text(text)
    if not cleaned:
        return []
    # Lowercase so "PMV" and "pmv" match, then drop punctuation that would
    # otherwise interfere with term matching.
    stripped = _PUNCTUATION_PATTERN.sub("", cleaned.lower())
    # str.split() with no argument already discards empty fragments; the
    # trailing filter is a defensive no-op kept for explicitness.
    return [token for token in stripped.split() if token]
# =============================================================================
# Score Normalization Utilities
# =============================================================================
# BM25 raw scores are unbounded positive values. We normalize to [0, 1]
# for consistency with dense retrieval and the RetrievalResult model.
# =============================================================================
def _normalize_scores(scores: list[float]) -> list[float]:
    """Min-max scale raw BM25 scores into the [0.0, 1.0] range.

    Raw BM25 scores are unbounded positive values; mapping them through
    ``(score - min) / (max - min)`` makes them comparable with dense
    retriever scores: the minimum becomes 0.0, the maximum 1.0, and the
    rest fall proportionally between.

    Edge cases:
    - Empty input -> empty output.
    - All scores equal and zero -> all 0.0 (nothing matched).
    - All scores equal and non-zero (including a single score) -> all 1.0
      (every result is equally the "best available").

    Args:
    ----
        scores: Raw BM25 scores (non-negative floats).

    Returns:
    -------
        Scores normalized into [0.0, 1.0].

    Example:
    -------
        >>> _normalize_scores([0.0, 0.5, 1.0])
        [0.0, 0.5, 1.0]
        >>> _normalize_scores([2.0, 4.0, 6.0])
        [0.0, 0.5, 1.0]
        >>> _normalize_scores([5.0, 5.0, 5.0])
        [1.0, 1.0, 1.0]
        >>> _normalize_scores([])
        []
    """
    if not scores:
        return []
    lowest, highest = min(scores), max(scores)
    spread = highest - lowest
    if spread == 0:
        # Degenerate case: every score identical. All-zero means no
        # relevance was detected; otherwise everything ties as "best".
        fill = 0.0 if highest == 0 else 1.0
        return [fill] * len(scores)
    # Normal case: classic min-max normalization.
    return [(value - lowest) / spread for value in scores]
# =============================================================================
# BM25Retriever Class
# =============================================================================
class BM25Retriever:
    """Sparse keyword retriever built on the Okapi BM25 ranking function.

    BM25 scores documents by term frequency and inverse document frequency,
    making it a strong complement to dense-embedding retrieval for exact
    keyword matches, acronyms, and out-of-vocabulary terms. The underlying
    index is a ``rank_bm25.BM25Okapi`` instance, imported lazily the first
    time ``build()`` or ``load()`` actually needs it.

    Parameters
    ----------
    k1 : float
        Term frequency saturation. Higher values weight repeated terms more
        heavily (0 reduces to binary term presence; typical range 1.2-2.0).
        Defaults to 1.5, the standard BM25 setting.
    b : float
        Document length normalization in [0, 1]; 0 disables length
        normalization, 1 applies it fully. Defaults to 0.75.

    Notes
    -----
    - NOT thread-safe: use separate instances or external synchronization.
    - The instance is unusable until ``build()`` or ``load()`` has run.
    - ``retrieve()`` yields ``(chunk_id, score)`` pairs with scores min-max
      normalized to [0, 1] over the returned top-k slice.

    Example
    -------
    >>> retriever = BM25Retriever(k1=1.5, b=0.75)
    >>> retriever.build(
    ...     corpus=["The PMV model predicts thermal sensation."],
    ...     chunk_ids=["chunk_001"],
    ... )
    >>> results = retriever.retrieve("PMV model", top_k=5)

    See Also
    --------
    - https://en.wikipedia.org/wiki/Okapi_BM25
    - https://github.com/dorianbrown/rank_bm25
    """

    def __init__(self, k1: float = 1.5, b: float = 0.75) -> None:
        """Store the BM25 parameters; no index is built and nothing heavy loads.

        Following the project's lazy-loading convention, rank_bm25 is only
        imported when build() or load() is called.

        Args:
        ----
            k1: Term frequency saturation parameter. Defaults to 1.5.
            b: Document length normalization parameter. Defaults to 0.75.

        Note:
        ----
            No range validation happens here; rank_bm25 handles k1/b itself.
        """
        # Scoring parameters forwarded to BM25Okapi at build time.
        self._k1: float = k1
        self._b: float = b
        # Index state stays None until build() or load() populates it.
        # _bm25: the BM25Okapi scoring index.
        # _tokenized_corpus: kept because BM25Okapi is rebuilt (not pickled)
        #     on save()/load().
        # _chunk_ids: maps corpus positions back to chunk identifiers.
        self._bm25: BM25Okapi | None = None
        self._tokenized_corpus: list[list[str]] | None = None
        self._chunk_ids: list[str] | None = None

    # -------------------------------------------------------------------
    # Private helpers
    # -------------------------------------------------------------------
    def _is_built(self) -> bool:
        """Return True when build() or load() has populated all index state."""
        state = (self._bm25, self._chunk_ids, self._tokenized_corpus)
        return all(part is not None for part in state)

    @property
    def chunk_ids(self) -> list[str]:
        """Chunk identifiers in index order.

        The ordering matches the order documents were indexed in, which
        matters when coordinating with sibling indexes (e.g. FAISS) built
        over the same corpus.

        Raises
        ------
        RuntimeError: If the index has not been built or loaded yet.
        """
        if self._chunk_ids is None:
            raise RuntimeError("Index not built - call build() or load() first")
        return self._chunk_ids

    # -------------------------------------------------------------------
    # Public API
    # -------------------------------------------------------------------
    def build(self, corpus: list[str], chunk_ids: list[str]) -> None:
        """Tokenize *corpus* and build a fresh BM25Okapi index over it.

        Idempotent in the sense that calling build() again simply replaces
        any previous index. Documents that tokenize to nothing are kept in
        the index but can never match a query.

        Args:
        ----
            corpus: Document texts to index, one string per document.
            chunk_ids: Unique identifiers, parallel to *corpus*; these are
                what retrieve() reports back.

        Raises:
        ------
            ValueError: If corpus is empty, if the two lists differ in
                length, or if every document tokenizes to nothing
                (e.g. a whitespace-only corpus).
        """
        if not corpus:
            raise ValueError("Cannot build BM25 index with empty corpus")
        if len(corpus) != len(chunk_ids):
            raise ValueError(
                f"corpus and chunk_ids length mismatch: "
                f"{len(corpus)} documents but {len(chunk_ids)} chunk_ids"
            )
        # Shared tokenizer keeps document and query processing symmetric.
        tokenized: list[list[str]] = [_tokenize(document) for document in corpus]
        # An index containing zero terms can never match anything; reject it.
        # (Empty token lists are falsy, so any() detects a usable document.)
        if not any(tokenized):
            raise ValueError(
                "All documents are empty after tokenization. "
                "Cannot build BM25 index with no terms."
            )
        # rank_bm25 is imported lazily, right before it is first needed.
        self._bm25 = _get_bm25_module().BM25Okapi(
            corpus=tokenized,
            k1=self._k1,
            b=self._b,
        )
        # Kept for persistence and for mapping result positions to ids.
        self._tokenized_corpus = tokenized
        self._chunk_ids = chunk_ids

    def retrieve(self, query: str, top_k: int = 10) -> list[tuple[str, float]]:
        """Return the top-k ``(chunk_id, score)`` pairs for *query*.

        The query is tokenized exactly like the indexed documents, every
        document is BM25-scored, and the top-k scores are min-max
        normalized to [0, 1] before being returned best-first. If the
        query shares no terms with the corpus, results carry score 0.0.

        Args:
        ----
            query: Non-empty search string.
            top_k: Maximum number of results; a positive integer, silently
                capped at corpus size. Defaults to 10.

        Returns:
        -------
            At most min(top_k, corpus size) pairs sorted by score
            descending, scores in [0.0, 1.0].

        Raises:
        ------
            RuntimeError: If called before build() or load().
            ValueError: If top_k is not a positive integer, or the query is
                empty, whitespace-only, or tokenizes to no terms (e.g. pure
                punctuation).
        """
        if not self._is_built():
            raise RuntimeError("BM25 index not built. Call build() or load() first.")
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"top_k must be a positive integer, got {top_k}")
        if not query or not query.strip():
            raise ValueError("query cannot be empty or whitespace-only")
        query_tokens = _tokenize(query)
        if not query_tokens:
            # Pure-punctuation queries survive the strip() check above but
            # yield no searchable terms.
            raise ValueError("query is empty after tokenization (no valid terms)")
        # Narrow the Optionals for the type checker; _is_built() guarantees
        # both are populated.
        assert self._bm25 is not None
        assert self._chunk_ids is not None
        # One BM25 score per document (numpy array from rank_bm25).
        raw_scores = self._bm25.get_scores(query_tokens)
        # Pair each corpus position with its score (as a Python float) so
        # the position survives sorting and maps back to a chunk id.
        ranked: list[tuple[int, float]] = [
            (position, float(value)) for position, value in enumerate(raw_scores)
        ]
        ranked.sort(key=lambda pair: pair[1], reverse=True)
        winners = ranked[:top_k]
        # Normalization is applied over the returned slice only.
        normalized = _normalize_scores([value for _, value in winners])
        return [
            (self._chunk_ids[position], scaled)
            for (position, _), scaled in zip(winners, normalized, strict=True)
        ]

    def save(self, path: Path) -> None:
        """Pickle the index state to *path*, creating parent directories.

        The BM25Okapi object itself is deliberately not pickled (numpy
        internals complicate it); instead the k1/b parameters, the
        tokenized corpus, and the chunk ids are stored so that load() can
        rebuild an identical index.

        Args:
        ----
            path: Destination file, conventionally with a ``.pkl``
                extension. Missing parent directories are created.

        Raises:
        ------
            RuntimeError: If the index has not been built yet.

        Note:
        ----
            Pickle format - restore only from trusted sources.
        """
        if not self._is_built():
            raise RuntimeError("Cannot save unbuilt BM25 index. Call build() first.")
        # Make nested destination paths work out of the box.
        path.parent.mkdir(parents=True, exist_ok=True)
        # Everything load() needs to reconstruct the index.
        payload: dict[str, Any] = {
            "k1": self._k1,
            "b": self._b,
            "tokenized_corpus": self._tokenized_corpus,
            "chunk_ids": self._chunk_ids,
        }
        with path.open("wb") as handle:
            pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def load(cls, path: Path) -> BM25Retriever:
        """Restore a retriever from a saved index file.

        Two pickle layouts are accepted:
        - build-pipeline format (embeddings/indexing.py):
          ``{"bm25": BM25Okapi, "chunk_ids": [...]}`` - the pickled index
          object is reused directly.
        - retrieval save() format:
          ``{"k1", "b", "tokenized_corpus", "chunk_ids"}`` - the BM25Okapi
          index is rebuilt from the stored token lists.

        Args:
        ----
            path: Pickle file to load. Only load trusted files - pickle can
                execute arbitrary code.

        Returns:
        -------
            A ready-to-use BM25Retriever.

        Raises:
        ------
            FileNotFoundError: If *path* does not exist.
        """
        if not path.exists():
            raise FileNotFoundError(f"BM25 index file not found: {path}")
        with path.open("rb") as handle:
            payload: dict[str, Any] = pickle.load(handle)
        chunk_ids: list[str] = payload["chunk_ids"]
        if "bm25" in payload:
            # Build-pipeline layout: the BM25Okapi object was pickled whole;
            # recover k1/b from its attributes, defaulting to the standard
            # settings when absent.
            index = payload["bm25"]
            restored = cls(
                k1=getattr(index, "k1", 1.5),
                b=getattr(index, "b", 0.75),
            )
            restored._bm25 = index
            # NOTE(review): BM25Okapi may not retain a "corpus" attribute
            # after construction, in which case this falls back to [] and a
            # later save() would persist an empty corpus - confirm against
            # rank_bm25 internals.
            restored._tokenized_corpus = getattr(index, "corpus", [])
            logger.debug(
                "Loaded BM25 from build pipeline format: %d chunks",
                len(chunk_ids),
            )
        else:
            # save() layout: rebuild the index from the stored token lists
            # with the stored parameters.
            tokenized: list[list[str]] = payload["tokenized_corpus"]
            restored = cls(k1=payload["k1"], b=payload["b"])
            restored._bm25 = _get_bm25_module().BM25Okapi(
                corpus=tokenized,
                k1=restored._k1,
                b=restored._b,
            )
            restored._tokenized_corpus = tokenized
            logger.debug(
                "Loaded BM25 from retrieval save format: %d chunks",
                len(chunk_ids),
            )
        restored._chunk_ids = chunk_ids
        return restored