rag-api-node-1 / src /infrastructure /adapters /bge_reranker_adapter.py
Peterase's picture
feat: hybrid RAG pipeline upgrade
daf250b
"""
Reranker Adapter β€” supports BGE-Reranker-v2-m3 AND Jina-Reranker-v3
Auto-detects which model to load based on RERANKER_MODEL setting:
- "BAAI/bge-reranker-v2-m3" β†’ FlagReranker (pointwise cross-encoder)
- "jinaai/jina-reranker-v3" β†’ Jina v3 listwise reranker
Jina v3 advantages over BGE for this project:
- Listwise: sees ALL docs at once β†’ better cross-doc comparison
- 131K context window β†’ reads full Jina-extracted articles (not just 512 chars)
- +9.6% better on English news (BEIR 61.94 vs 56.51)
- Better Arabic ranking (78.69 nDCG)
- Same size (0.6B), same memory, same cost (free, self-hosted)
Thread-safe lazy loading β€” model loads once on first rerank call.
"""
import logging
import threading
from typing import List, Dict, Any, Optional
from src.core.config import settings
from src.core.ports.reranker_port import RerankerPort
logger = logging.getLogger(__name__)

# ── Patch transformers compatibility issue ────────────────────────────────────
# Newer transformers releases may not define `is_torch_fx_available`, which some
# dependent packages still import (presumably FlagEmbedding or one of its
# dependencies — TODO confirm). Install a stub returning False. This must run
# before the optional imports below.
try:
    import transformers.utils.import_utils as _tui

    if not hasattr(_tui, "is_torch_fx_available"):
        _tui.is_torch_fx_available = lambda: False
except Exception as _patch_err:  # best-effort: transformers may be absent or restructured
    logger.debug(f"transformers compatibility patch skipped: {_patch_err}")

# ── Try FlagEmbedding (for BGE) ───────────────────────────────────────────────
try:
    from FlagEmbedding import FlagReranker

    HAS_FLAG_RERANKER = True
except ImportError:
    HAS_FLAG_RERANKER = False

# ── Try sentence-transformers CrossEncoder (BGE fallback) ────────────────────
try:
    from sentence_transformers import CrossEncoder

    HAS_CROSS_ENCODER = True
except ImportError:
    HAS_CROSS_ENCODER = False

# ── Try transformers (for Jina v3) ────────────────────────────────────────────
try:
    import torch
    from transformers import AutoModel

    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False
    logger.warning("transformers/torch not available — Jina v3 reranker disabled.")
# ═══════════════════════════════════════════════════════════════════════════════
# JINA V3 RERANKER
# ═══════════════════════════════════════════════════════════════════════════════
class JinaV3Reranker:
    """
    Jina-Reranker-v3 self-hosted reranker.

    Key differences from BGE pointwise:
    - Listwise: processes all docs in one forward pass
    - 131K context window: can read full articles, not just the first 512 chars
    - Built on Qwen3-0.6B backbone with causal self-attention
    - State-of-the-art BEIR: 61.94 nDCG@10 (vs BGE's 56.51)

    Scoring: delegates to the model's own ``.rerank()`` method (provided by the
    remote code loaded with ``trust_remote_code``) and uses the
    ``relevance_score`` it returns (presumably normalized to 0-1 — confirm
    against the model card).

    The model is loaded lazily, at most once. A failed load is remembered and
    not retried.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        self._model = None  # assigned only after a successful, complete load
        self._lock = threading.Lock()  # serializes the one-time load
        self._load_failed = False
        self._device = "cpu"  # device the model actually ended up on

    def _load(self) -> None:
        """
        Load the model on first use (thread-safe, re-checked under ``_lock``).

        Never raises. On failure, ``_load_failed`` is set and the error is logged.
        """
        # Fast path without the lock: already loaded, or a previous attempt failed.
        if self._model is not None or self._load_failed:
            return
        with self._lock:
            # Re-check: another thread may have finished (or failed) while we waited.
            if self._model is not None or self._load_failed:
                return
            if not HAS_TRANSFORMERS:
                logger.error("transformers not installed — cannot load Jina v3")
                self._load_failed = True
                return
            try:
                logger.info(f"Loading Jina v3 reranker: {self.model_name}")
                device = "cuda" if torch.cuda.is_available() else "cpu"
                # Jina v3 uses AutoModel (NOT AutoModelForSequenceClassification).
                # Its remote code exposes a .rerank() method returning relevance scores.
                model = AutoModel.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    dtype="auto",
                )
                if device != "cpu":
                    # Actually place the weights on the selected device.
                    # Previously the device was only logged, and the model stayed
                    # where from_pretrained put it (CPU).
                    try:
                        model.to(device)
                    except Exception as move_err:
                        logger.warning(
                            f"Could not move Jina v3 to {device} ({move_err}) — using CPU"
                        )
                        device = "cpu"
                model.eval()
            except Exception as e:
                logger.error(f"Failed to load Jina v3 reranker: {e}", exc_info=True)
                self._load_failed = True
                return
            self._device = device
            # Publish the model last. The unlocked fast path above and is_loaded
            # must never observe a half-initialised model.
            self._model = model
            logger.info(
                f"✅ Jina v3 reranker loaded on {self._device} "
                f"(model={self.model_name})"
            )

    def compute_scores(
        self,
        query: str,
        docs: List[str],
        max_length: int = 1024,
    ) -> List[float]:
        """
        Score every (query, doc) pair with the model's built-in ``.rerank()``.

        Args:
            query: The search query.
            docs: Document texts. Any truncation is the caller's responsibility.
            max_length: Accepted for interface compatibility. It is currently
                NOT forwarded to the model.

        Returns:
            One float per doc, in the ORIGINAL doc order (not sorted).
            If the model could not be loaded, every score is 0.5. If
            ``.rerank()`` fails, the error is logged and every score is 0.0.
            This method does not raise.
        """
        if not docs:
            return []
        self._load()
        if self._model is None:
            return [0.5] * len(docs)
        try:
            # .rerank() returns dicts like
            #   {"document": str, "relevance_score": float, "index": int}
            # sorted by relevance_score descending. Put each score back at its
            # original position using "index".
            results = self._model.rerank(query, docs)
            scores = [0.0] * len(docs)
            for result in results:
                scores[result["index"]] = float(result["relevance_score"])
            return scores
        except Exception as e:
            # NOTE(review): a failure yields all-zero scores instead of an
            # exception, so callers cannot tell it apart from "all irrelevant".
            logger.error(f"Jina v3 rerank() failed: {e}")
            return [0.0] * len(docs)

    @property
    def is_loaded(self) -> bool:
        """True once the model has been loaded successfully."""
        return self._model is not None
# ═══════════════════════════════════════════════════════════════════════════════
# UNIFIED RERANKER ADAPTER
# ═══════════════════════════════════════════════════════════════════════════════
class BgeRerankerAdapter(RerankerPort):
    """
    Unified reranker adapter — auto-selects the reranking backend from config.

    RERANKER_MODEL=jinaai/jina-reranker-v3 → Jina v3 (recommended)
    RERANKER_MODEL=BAAI/bge-reranker-v2-m3 → BGE (legacy)

    When JINA_RERANKER_ENABLED is true and a real JINA_API_KEY is set, the Jina
    cloud API is tried first. If it reports itself unavailable, the self-hosted
    model named by RERANKER_MODEL is used instead.

    Both self-hosted models are free, ~0.6B parameters, ~1.2GB on disk, and are
    loaded lazily on first use.
    """

    # Max content chars sent to the reranker per document.
    # Jina v3: ~1024 tokens ≈ 4096 chars — much more than BGE's 512 chars.
    MAX_CONTENT_CHARS_JINA = 4096
    MAX_CONTENT_CHARS_BGE = 512

    # Template value for JINA_API_KEY that must not be treated as a real key.
    _PLACEHOLDER_JINA_KEY = "your-jina-api-key-here"

    def __init__(self):
        self.model_name = settings.RERANKER_MODEL
        self._is_jina_v3 = "jina-reranker-v3" in self.model_name.lower()
        self._lock = threading.Lock()  # guards the lazy BGE load
        self._load_failed = False
        # Lazily-loaded BGE model. _use_flag is True when it is a FlagReranker
        # (compute_score API) rather than a CrossEncoder (predict API).
        self._bge_model = None
        self._use_flag = False

        # Optional Jina cloud API client; takes priority over self-hosted models.
        self._jina_api = self._init_jina_api()

        # Self-hosted Jina v3 wrapper. Creating it is cheap (the model loads on
        # first use), so build it whenever Jina v3 is configured. It is also the
        # fallback when the cloud API is configured but unavailable. Previously
        # it was skipped in that case, and the fallback tried to load the Jina
        # checkpoint through the BGE code path.
        self._jina = JinaV3Reranker(self.model_name) if self._is_jina_v3 else None

        if self._jina_api is None:
            if self._is_jina_v3:
                logger.info(f"Reranker configured: Jina v3 self-hosted ({self.model_name})")
            else:
                logger.info(f"Reranker configured: BGE ({self.model_name})")

    def _init_jina_api(self) -> Optional[Any]:
        """Build the Jina cloud reranker client, or return None if disabled or misconfigured."""
        if not getattr(settings, "JINA_RERANKER_ENABLED", False):
            return None
        jina_key = getattr(settings, "JINA_API_KEY", "")
        if not jina_key or jina_key == self._PLACEHOLDER_JINA_KEY:
            return None
        try:
            from src.infrastructure.adapters.jina_reranker_adapter import JinaRerankerAPIAdapter

            api = JinaRerankerAPIAdapter(
                api_key=jina_key,
                model=getattr(settings, "JINA_RERANKER_MODEL", "jina-reranker-v3"),
                timeout=getattr(settings, "JINA_RERANKER_TIMEOUT", 5.0),
            )
        except Exception as e:
            logger.warning(f"Jina API reranker init failed: {e}")
            return None
        logger.info("Reranker configured: Jina API (cloud, fast)")
        return api

    @staticmethod
    def _patch_xlmr_tokenizer() -> None:
        """Best-effort: copy methods missing on XLMRobertaTokenizer from PreTrainedTokenizer."""
        # Works around older transformers versions; never raises.
        try:
            from transformers import XLMRobertaTokenizer, PreTrainedTokenizer

            for method_name in (
                "prepare_for_model",
                "build_inputs_with_special_tokens",
                "create_token_type_ids_from_sequences",
                "get_special_tokens_mask",
                "convert_tokens_to_string",
            ):
                if not hasattr(XLMRobertaTokenizer, method_name):
                    base_method = getattr(PreTrainedTokenizer, method_name, None)
                    if base_method:
                        setattr(XLMRobertaTokenizer, method_name, base_method)
        except Exception as patch_err:
            logger.debug(f"Tokenizer patch skipped: {patch_err}")

    def _load_bge(self) -> None:
        """Lazy-load the BGE reranker once (thread-safe). On failure, set _load_failed instead of raising."""
        if self._bge_model is not None or self._load_failed:
            return
        with self._lock:
            if self._bge_model is not None or self._load_failed:
                return
            logger.info(f"Loading BGE reranker: {self.model_name}")
            try:
                if HAS_FLAG_RERANKER and "bge-reranker" in self.model_name.lower():
                    self._patch_xlmr_tokenizer()
                    model = FlagReranker(
                        self.model_name,
                        use_fp16=True,
                        normalize=True,
                        trust_remote_code=True,
                    )
                    use_flag = True
                    loaded_msg = "✅ BGE loaded via FlagReranker (fp16, multilingual)"
                elif HAS_CROSS_ENCODER:
                    model = CrossEncoder(self.model_name)
                    use_flag = False
                    loaded_msg = "✅ BGE loaded via CrossEncoder (fallback)"
                else:
                    logger.error(
                        "No BGE backend available (FlagEmbedding or sentence-transformers required)"
                    )
                    self._load_failed = True
                    return
            except Exception as e:
                logger.error(f"Failed to load BGE reranker '{self.model_name}': {e}", exc_info=True)
                self._load_failed = True
                return
            # Set the backend flag BEFORE publishing the model. A thread that sees
            # the model via the unlocked check must never pair it with a stale flag.
            self._use_flag = use_flag
            self._bge_model = model
            logger.info(loaded_msg)

    # ── Shared helpers ────────────────────────────────────────────────────────
    @staticmethod
    def _vector_fallback(docs: List[Dict[str, Any]], top_n: int) -> List[Dict[str, Any]]:
        """Order docs by their upstream retrieval "score"; used when reranking is unavailable."""
        # `or 0` treats a missing OR None score as 0 so sorting cannot raise.
        return sorted(docs, key=lambda d: d.get("score") or 0, reverse=True)[:top_n]

    @staticmethod
    def _collect_texts(docs: List[Dict[str, Any]], max_chars: int):
        """Return (docs with non-empty content, their stripped content truncated to max_chars)."""
        valid_docs: List[Dict[str, Any]] = []
        texts: List[str] = []
        for doc in docs:
            # `or ""` also guards content=None, which would crash .strip().
            content = (doc.get("content") or "").strip()
            if content:
                valid_docs.append(doc)
                texts.append(content[:max_chars])
        return valid_docs, texts

    @staticmethod
    def _rank_by_scores(
        label: str,
        valid_docs: List[Dict[str, Any]],
        scores,
        top_n: int,
    ) -> List[Dict[str, Any]]:
        """Write rerank_score into each doc (mutating it), sort descending, log, and return the top_n."""
        if len(scores) != len(valid_docs):
            raise ValueError(f"{label} returned {len(scores)} scores for {len(valid_docs)} docs")
        for doc, score in zip(valid_docs, scores):
            doc["rerank_score"] = float(score)
        valid_docs.sort(key=lambda d: d["rerank_score"], reverse=True)
        logger.info(
            f"[Reranker] {label}: {len(valid_docs)} docs → top {top_n} "
            f"(max_score={valid_docs[0]['rerank_score']:.3f})"
        )
        return valid_docs[:top_n]

    # ── Public interface ──────────────────────────────────────────────────────
    def rerank(
        self,
        query: str,
        docs: List[Dict[str, Any]],
        top_n: int = 5,
    ) -> List[Dict[str, Any]]:
        """
        Rerank documents by relevance to query.

        Priority: Jina API (cloud) > Jina v3 self-hosted > BGE.
        The Jina v3 path uses up to MAX_CONTENT_CHARS_JINA chars of each doc's
        "content". The BGE path uses up to MAX_CONTENT_CHARS_BGE chars.
        Docs with empty content are dropped. Scored docs get a "rerank_score" key
        (the input dicts are mutated). If a self-hosted model is unavailable or
        fails, docs are ordered by their "score" key instead.

        Returns:
            Up to top_n docs, sorted by score descending.
        """
        if not docs:
            return []
        if self._jina_api is not None and self._jina_api.is_available():
            return self._jina_api.rerank(query, docs, top_n)
        if self._jina is not None:
            return self._rerank_jina(query, docs, top_n)
        return self._rerank_bge(query, docs, top_n)

    def _rerank_jina(
        self,
        query: str,
        docs: List[Dict[str, Any]],
        top_n: int,
    ) -> List[Dict[str, Any]]:
        """Rerank with self-hosted Jina v3, sending up to MAX_CONTENT_CHARS_JINA chars per doc."""
        self._jina._load()
        if self._jina._load_failed or not self._jina.is_loaded:
            logger.warning("Jina v3 unavailable — falling back to vector score ordering")
            return self._vector_fallback(docs, top_n)

        valid_docs, doc_texts = self._collect_texts(docs, self.MAX_CONTENT_CHARS_JINA)
        if not doc_texts:
            return []
        try:
            scores = self._jina.compute_scores(query, doc_texts)
            return self._rank_by_scores("Jina v3", valid_docs, scores, top_n)
        except Exception as e:
            logger.error(f"Jina v3 reranking failed: {e} — falling back to vector score")
            return self._vector_fallback(docs, top_n)

    def _rerank_bge(
        self,
        query: str,
        docs: List[Dict[str, Any]],
        top_n: int,
    ) -> List[Dict[str, Any]]:
        """Rerank with BGE, sending up to MAX_CONTENT_CHARS_BGE chars per doc."""
        if self._bge_model is None:
            self._load_bge()
        if self._bge_model is None:
            logger.warning("BGE unavailable — falling back to vector score ordering")
            return self._vector_fallback(docs, top_n)

        valid_docs, texts = self._collect_texts(docs, self.MAX_CONTENT_CHARS_BGE)
        if not texts:
            return []
        pairs = [[query, text] for text in texts]
        try:
            if self._use_flag:
                scores = self._bge_model.compute_score(pairs, batch_size=64)
            else:
                scores = self._bge_model.predict(pairs)
            # A single pair may come back as a bare float rather than a sequence.
            if isinstance(scores, float):
                scores = [scores]
            return self._rank_by_scores("BGE", valid_docs, scores, top_n)
        except Exception as e:
            logger.error(f"BGE reranking failed: {e} — falling back to vector score")
            return self._vector_fallback(docs, top_n)

    @property
    def model_type(self) -> str:
        """Family of the configured self-hosted model: "jina_v3" or "bge" (the cloud API is not reflected)."""
        return "jina_v3" if self._is_jina_v3 else "bge"