# src/infrastructure/adapters/bge_embedder_adapter.py
# feat(rag): implement hybrid search with live sources and production-grade
# intent classification (commit a63c61f)
import logging
import os
from typing import Dict, Any, List
from src.core.config import settings
from src.core.ports.embedder_port import EmbedderPort
logger = logging.getLogger(__name__)
# On Windows, HF Hub cache symlinks frequently fail without developer mode /
# admin rights; disabling them avoids download-time errors. Must be set
# before any huggingface_hub import — presumably why it sits above the
# FlagEmbedding import below.
if os.name == 'nt':
    os.environ["HF_HUB_DISABLE_SYMLINKS"] = "1"
try:
    # HACK: some transformers builds lack `is_torch_fx_available`, which the
    # FlagEmbedding import path probes; stub it out before importing so the
    # import does not blow up on an AttributeError.
    import transformers.utils.import_utils
    if not hasattr(transformers.utils.import_utils, 'is_torch_fx_available'):
        transformers.utils.import_utils.is_torch_fx_available = lambda: False
    from FlagEmbedding import BGEM3FlagModel
    HAS_FLAG_EMBEDDING = True
except ImportError as e:
    # Graceful degradation: the adapter below falls back to dummy embeddings
    # when this flag is False instead of crashing at import time.
    HAS_FLAG_EMBEDDING = False
    logger.warning(f"FlagEmbedding not installed: {e}. Using dummy embeddings.")
class BgeEmbedderAdapter(EmbedderPort):
    """Embedder backed by BGE-M3 (FlagEmbedding): dense + sparse hybrid vectors.

    The model is loaded lazily on first use so constructing the adapter is
    cheap. When FlagEmbedding is not installed (HAS_FLAG_EMBEDDING is False),
    every method degrades to a deterministic dummy/None result so the rest of
    the pipeline keeps functioning.
    """

    def __init__(self):
        # Lazily initialised BGEM3FlagModel; None until first encode call.
        self.model = None
        self.model_name = settings.EMBEDDING_MODEL

    @staticmethod
    def _sparse_from_lexical(lexical_dict: Dict[Any, Any]) -> Dict[str, Any]:
        """Convert BGE-M3 lexical weights ({token_id: weight}) into the
        {"indices": [...], "values": [...]} shape used by the vector store.

        Token ids arrive as strings from the model output; weights may be
        numpy scalars — both are coerced to plain int/float for serialization.
        """
        return {
            "indices": [int(k) for k in lexical_dict.keys()],
            "values": [float(v) for v in lexical_dict.values()],
        }

    def _load_model(self) -> None:
        """Load the BGE-M3 model; idempotent no-op if already loaded.

        Raises:
            Exception: re-raises whatever the underlying loader raises so the
                caller sees the real failure (an absent FlagEmbedding install
                is handled separately and does NOT raise).
        """
        if self.model is not None:
            return
        if not HAS_FLAG_EMBEDDING:
            logger.warning("FlagEmbedding not installed. Using dummy embeddings.")
            return
        logger.info(f"Loading embedding model: {self.model_name}")
        try:
            # fp16 halves memory use and speeds up inference.
            self.model = BGEM3FlagModel(self.model_name, use_fp16=True)
            logger.info(f"Successfully loaded {self.model_name} (Hybrid Mode)")
        except Exception as e:
            logger.error(f"Failed to load embedding model: {e}", exc_info=True)
            # Bare raise preserves the original traceback (raise e would
            # restart the traceback from this frame).
            raise

    def encode_query(self, text: str) -> Dict[str, Any]:
        """Encodes a query string into dense and sparse vectors.

        Args:
            text: The query text to embed.

        Returns:
            {"dense": list[float], "sparse": {"indices": [...], "values": [...]}}.
            When the model is unavailable, a fixed dummy dense vector and a
            None sparse vector are returned instead.
        """
        self._load_model()
        if not HAS_FLAG_EMBEDDING or self.model is None:
            # Dummy fallback keeps downstream dense search syntactically valid.
            return {
                "dense": [0.1] * settings.VECTOR_SIZE,
                "sparse": None
            }
        embeddings = self.model.encode(
            sentences=[text],
            batch_size=1,
            max_length=512,
            return_dense=True,
            return_sparse=True,
            return_colbert_vecs=False
        )
        return {
            "dense": embeddings['dense_vecs'][0].tolist(),
            "sparse": self._sparse_from_lexical(embeddings['lexical_weights'][0]),
        }

    def encode_sparse_only(self, text: str) -> Dict[str, Any]:
        """
        Encodes only the sparse (BM25/lexical) vector for a single query.
        Skips dense computation — ~2x faster than encode_query.
        Used for per-language sparse queries when the dense vector is
        already available from the English query.

        Args:
            text: The query text to embed.

        Returns:
            {"sparse": {...}} on success, or {"sparse": None} when the model
            is unavailable.
        """
        self._load_model()
        if not HAS_FLAG_EMBEDDING or self.model is None:
            return {"sparse": None}
        embeddings = self.model.encode(
            sentences=[text],
            batch_size=1,
            max_length=512,
            return_dense=False,  # skip dense — saves ~60% compute
            return_sparse=True,
            return_colbert_vecs=False
        )
        return {"sparse": self._sparse_from_lexical(embeddings['lexical_weights'][0])}

    def encode_sparse_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
        """
        Encode multiple texts into sparse vectors in a SINGLE model forward pass.

        Why this matters:
        BGE-M3 holds the Python GIL during inference — ThreadPoolExecutor gives
        zero benefit for CPU-bound model calls. Calling encode_sparse_only() 6
        times in a thread pool still runs sequentially. This method batches all
        6 language queries into one model.encode() call, which is ~5x faster
        than 6 sequential calls because:
          - One tokenization pass for all texts
          - One forward pass through the transformer
          - GPU/CPU utilisation is much higher with batch_size=6 vs batch_size=1

        Args:
            texts: The texts to embed; an empty list short-circuits to [].

        Returns:
            A list of {"sparse": {...}} dicts in the same order as `texts`.
            Falls back to {"sparse": None} entries on failure or when the
            model is unavailable (deliberate best-effort — callers treat a
            missing sparse vector as "dense-only search").
        """
        if not texts:
            return []
        self._load_model()
        if not HAS_FLAG_EMBEDDING or self.model is None:
            return [{"sparse": None} for _ in texts]
        try:
            embeddings = self.model.encode(
                sentences=texts,
                batch_size=len(texts),  # all in one shot
                max_length=512,
                return_dense=False,  # skip dense — not needed here
                return_sparse=True,
                return_colbert_vecs=False
            )
            return [
                {"sparse": self._sparse_from_lexical(lexical_dict)}
                for lexical_dict in embeddings['lexical_weights']
            ]
        except Exception as e:
            logger.error(f"encode_sparse_batch failed: {e} — returning empty sparse vectors")
            return [{"sparse": None} for _ in texts]