# Prepared for HF Space deployment by sadickam (commit d01a7e3)
"""BGE embedding encoder for text chunks.
This module provides the BGEEncoder class for generating high-quality
embeddings from text using BAAI General Embedding (BGE) models. The
encoder is optimized for:
- Batch processing for efficiency
- Float16 output for memory savings
- GPU acceleration when available
Lazy Loading:
torch and sentence-transformers are loaded on first use to avoid
import overhead when embeddings are not needed. This is critical
for the serve pipeline where embeddings may not be needed (using
prebuilt FAISS index instead).
Design Decisions:
- The model is loaded lazily on first encode() call, not in __init__
- Text normalization is applied before encoding to fix OCR artifacts
- Progress callbacks enable integration with CLI progress bars
- Float16 output reduces memory usage by 50% with minimal quality loss
"""
from __future__ import annotations
import math
from typing import TYPE_CHECKING
import numpy as np
# =============================================================================
# Type Checking Imports
# =============================================================================
# These imports are only processed by type checkers (mypy, pyright) and IDEs.
# At runtime, we use lazy imports inside methods to avoid loading heavy
# dependencies (torch, sentence-transformers) until actually needed.
# =============================================================================
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from numpy.typing import NDArray
# =============================================================================
# Module Exports
# =============================================================================
__all__: list[str] = ["BGEEncoder"]
# =============================================================================
# Constants
# =============================================================================
# Default model for embedding generation.
# BAAI/bge-small-en-v1.5 provides 384-dimensional embeddings with excellent
# retrieval quality while remaining small enough for CPU inference.
_DEFAULT_MODEL_NAME: str = "BAAI/bge-small-en-v1.5"
# Output dimension of the default model. Other BGE models differ:
# bge-small-en-v1.5: 384, bge-base-en-v1.5: 768, bge-large-en-v1.5: 1024.
_BGE_SMALL_EMBEDDING_DIM: int = 384


class BGEEncoder:
    """Generate embeddings using BGE models from BAAI.

    Encodes text into dense vector embeddings with a SentenceTransformer
    model (default 'BAAI/bge-small-en-v1.5', 384 dimensions). The encoder
    handles:

    - Lazy model loading on first use (not in __init__)
    - Batch processing for memory-efficient encoding
    - GPU/CPU device management with auto-detection
    - Float16 conversion for storage efficiency
    - Optional text normalization to fix OCR/extraction artifacts

    Lazy Loading Pattern:
        torch and sentence-transformers are imported only on the first
        encode() call. This keeps startup fast and memory low in the
        serve pipeline, where a prebuilt FAISS index may make embedding
        generation unnecessary.

    Attributes:
    ----------
    model_name : str
        HuggingFace identifier of the BGE model in use.
    device : str
        Device used for inference ('cuda' or 'cpu'); accessing this
        property triggers lazy model loading if needed.
    embedding_dim : int
        Dimension of output embeddings (384 for bge-small-en-v1.5).

    Example:
    -------
    >>> encoder = BGEEncoder()
    >>> embeddings = encoder.encode(["Hello world", "Test text"])
    >>> print(embeddings.shape)
    (2, 384)
    >>> print(embeddings.dtype)
    float16

    Note:
    ----
    The encoder implements the Encoder protocol defined in
    rag_chatbot.embeddings.models, enabling dependency injection
    and testing with mock encoders.
    """

    def __init__(
        self,
        model_name: str = _DEFAULT_MODEL_NAME,
        device: str | None = None,
        normalize_text: bool = True,
    ) -> None:
        """Store configuration; the model itself is loaded lazily.

        The model is NOT loaded here — loading happens on the first
        encode() call so the encoder can be instantiated quickly without
        importing heavy dependencies.

        Args:
        ----
        model_name : str
            HuggingFace model identifier for the BGE model.
            Defaults to 'BAAI/bge-small-en-v1.5'.
        device : str | None
            Device to use for inference: 'cuda', 'cpu', or None to
            auto-detect (CUDA if available, else CPU).
        normalize_text : bool
            If True, apply text normalization before encoding to fix
            common PDF extraction artifacts (jumbled words, extra
            spaces, ALL CAPS text). Defaults to True.

        Example:
        -------
        >>> encoder = BGEEncoder()                     # auto-detect device
        >>> encoder = BGEEncoder(device="cpu")         # force CPU
        >>> encoder = BGEEncoder(normalize_text=False) # pre-processed text
        """
        # Configuration captured for the deferred load in
        # _ensure_model_loaded().
        self._model_name: str = model_name
        self._requested_device: str | None = device
        self._normalize_text: bool = normalize_text
        # Lazily-populated state; None is the "not initialized" sentinel.
        self._model: object | None = None       # SentenceTransformer instance
        self._device: str | None = None         # resolved device string
        self._normalizer: object | None = None  # TextNormalizer instance

    def _ensure_model_loaded(self) -> None:
        """Idempotently import dependencies and load the model.

        Called internally before any encoding operation. Imports torch
        and sentence-transformers, resolves the inference device (an
        explicit request wins over CUDA auto-detection), loads the
        SentenceTransformer, and creates the TextNormalizer when
        normalization is enabled. Safe to call multiple times; the model
        is loaded at most once.

        Raises:
        ------
        ImportError
            If torch or sentence-transformers are not installed.
        RuntimeError
            If model loading fails for any reason.
        """
        if self._model is not None:
            return  # already initialized

        # Lazy imports: torch (500MB+) and sentence-transformers are
        # pulled in only when encoding is actually needed. Crucial for
        # fast startup in the serve pipeline.
        import torch
        from sentence_transformers import SentenceTransformer

        # Honor an explicit device request; otherwise prefer CUDA when
        # torch reports it as available.
        if self._requested_device is not None:
            self._device = self._requested_device
        else:
            self._device = "cuda" if torch.cuda.is_available() else "cpu"

        # SentenceTransformer downloads and caches the model from the
        # HuggingFace Hub, then loads it onto the chosen device.
        self._model = SentenceTransformer(
            model_name_or_path=self._model_name,
            device=self._device,
        )

        if self._normalize_text:
            # Imported here to preserve the lazy-loading pattern.
            from rag_chatbot.chunking.models import TextNormalizer

            self._normalizer = TextNormalizer()

    def encode(
        self,
        texts: Sequence[str],
        batch_size: int = 32,
        show_progress: bool = False,
        progress_callback: Callable[[int, int], None] | None = None,
    ) -> NDArray[np.float16]:
        """Encode texts into L2-normalized float16 embedding vectors.

        Processing steps: validate arguments, load the model if needed,
        normalize text if enabled, encode in batches, convert the result
        to float16, and invoke the progress callback after each batch.

        Args:
        ----
        texts : Sequence[str]
            Text strings (documents or chunks) to embed.
        batch_size : int
            Number of texts per batch. Larger batches are faster but use
            more memory. Must be >= 1. Default is 32.
        show_progress : bool
            Whether to show SentenceTransformer's progress bar during
            encoding. Default is False.
        progress_callback : Callable[[int, int], None] | None
            Optional callback invoked after each batch with
            (current_batch_index, total_batches); indices are 1-based
            for human readability.

        Returns:
        -------
        NDArray[np.float16]
            Array of shape (len(texts), embedding_dim) with float16
            dtype. An empty input yields an empty (0, embedding_dim)
            array rather than raising, so downstream code need not
            special-case it.

        Raises:
        ------
        ValueError
            If batch_size is less than 1.
        RuntimeError
            If encoding fails due to model issues.

        Example:
        -------
        >>> encoder = BGEEncoder()
        >>> embeddings = encoder.encode(["Hello world", "Test text"])
        >>> print(embeddings.shape)
        (2, 384)
        >>> print(embeddings.dtype)
        float16

        Note:
        ----
        The float16 output halves memory usage versus float32 with
        negligible impact on retrieval quality for most applications.
        """
        # Reject nonsensical batch sizes up front: batch_size=0 would
        # otherwise raise ZeroDivisionError inside math.ceil below.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Empty input: return a correctly-shaped empty array so
        # downstream processing (e.g. FAISS index building) keeps
        # working. This path does not load the model for the default
        # model name, since embedding_dim is a known constant there.
        if len(texts) == 0:
            return np.empty((0, self.embedding_dim), dtype=np.float16)

        self._ensure_model_loaded()

        # Normalize to fix PDF extraction artifacts (jumbled words such
        # as "ther mal" -> "thermal", doubled spaces). is_heading=False
        # selects regular body-text rules.
        if self._normalize_text and self._normalizer is not None:
            processed_texts: list[str] = [
                self._normalizer.normalize(text, is_heading=False)  # type: ignore[attr-defined]
                for text in texts
            ]
        else:
            # No normalization; copy to a list for consistent handling.
            processed_texts = list(texts)

        # Total batch count is needed upfront for the progress callback;
        # math.ceil counts a partial final batch.
        total_batches: int = math.ceil(len(processed_texts) / batch_size)

        if progress_callback is not None:
            # Manual batching so progress can be reported per batch.
            all_embeddings: list[NDArray[np.float32]] = []
            for batch_idx in range(total_batches):
                start_idx: int = batch_idx * batch_size
                batch_texts: list[str] = processed_texts[
                    start_idx : start_idx + batch_size
                ]
                # convert_to_numpy=True returns an ndarray, not a tensor;
                # normalize_embeddings=True applies the L2 normalization
                # that is standard for BGE models.
                batch_embeddings: NDArray[np.float32] = self._model.encode(  # type: ignore[union-attr]
                    sentences=batch_texts,
                    batch_size=batch_size,
                    show_progress_bar=show_progress,
                    convert_to_numpy=True,
                    normalize_embeddings=True,
                )
                all_embeddings.append(batch_embeddings)
                # 1-based indices for human-readable progress output.
                progress_callback(batch_idx + 1, total_batches)
            embeddings: NDArray[np.float32] = np.concatenate(
                all_embeddings, axis=0
            )
        else:
            # No callback: let SentenceTransformer batch internally,
            # which is more efficient.
            embeddings = self._model.encode(  # type: ignore[union-attr]
                sentences=processed_texts,
                batch_size=batch_size,
                show_progress_bar=show_progress,
                convert_to_numpy=True,
                normalize_embeddings=True,
            )

        # Convert after encoding: models compute in float32 internally
        # for numerical stability; float16 halves storage afterwards.
        return embeddings.astype(np.float16)

    @property
    def embedding_dim(self) -> int:
        """Dimension of embeddings produced by this encoder.

        Used for initializing FAISS index dimensions, validating
        embedding arrays, and pre-allocating storage.

        Returns:
        -------
        int
            Output embedding dimension (384 for bge-small-en-v1.5).

        Note:
        ----
        For the default model this is a known constant and does NOT
        trigger model loading. For any other model the model must be
        loaded to query its dimension.
        """
        if self._model_name == _DEFAULT_MODEL_NAME:
            return _BGE_SMALL_EMBEDDING_DIM
        # Unknown model: loading it is the only reliable way to learn
        # the dimension.
        self._ensure_model_loaded()
        return self._model.get_sentence_embedding_dimension()  # type: ignore[union-attr, no-any-return]

    @property
    def model_name(self) -> str:
        """HuggingFace identifier of the embedding model.

        Useful for logging, debugging, and ensuring consistency across
        pipeline stages.

        Returns:
        -------
        str
            Model name (e.g. 'BAAI/bge-small-en-v1.5').
        """
        return self._model_name

    @property
    def device(self) -> str:
        """Device used for inference ('cpu' or 'cuda').

        Accessing this property triggers lazy model loading if the
        device has not been resolved yet, since device detection happens
        during model initialization.

        Returns:
        -------
        str
            Device identifier ('cpu' or 'cuda').
        """
        if self._device is None:
            self._ensure_model_loaded()
        # _device is guaranteed to be set after loading.
        return self._device  # type: ignore[return-value]