Spaces:
Running
Running
"""
Reranker adapter — supports BGE-Reranker-v2-m3 AND Jina-Reranker-v3.

Auto-detects which model to load from the RERANKER_MODEL setting:
  - "BAAI/bge-reranker-v2-m3" → FlagReranker (pointwise cross-encoder)
  - "jinaai/jina-reranker-v3" → Jina v3 listwise reranker

Jina v3 advantages over BGE for this project:
  - Listwise: sees ALL docs at once → better cross-doc comparison
  - 131K context window → can read far more of each Jina-extracted article
    (this adapter sends up to 4096 chars per doc, vs 512 for BGE)
  - +9.6% better on English news (BEIR 61.94 vs 56.51)
  - Better Arabic ranking (78.69 nDCG)
  - Same size (0.6B), same memory, same cost (free, self-hosted)

Thread-safe lazy loading — each model loads once, on the first rerank call.
"""
| import logging | |
| import threading | |
| from typing import List, Dict, Any, Optional | |
| from src.core.config import settings | |
| from src.core.ports.reranker_port import RerankerPort | |
| logger = logging.getLogger(__name__) | |
# ── Patch transformers compatibility issue ───────────────────────────────────
# Some transformers releases do not define `is_torch_fx_available`; stub it
# (always False) before the optional imports below can look it up.
# NOTE(review): presumably required by FlagEmbedding or the Jina remote code —
# confirm which import needs it.
try:
    import transformers.utils.import_utils as _tui

    if not hasattr(_tui, "is_torch_fx_available"):
        _tui.is_torch_fx_available = lambda: False
except Exception as _patch_err:
    # Best-effort: transformers may be missing or structured differently.
    # Log rather than silently swallow so a skipped patch is diagnosable.
    logger.debug(f"transformers compatibility patch skipped: {_patch_err}")

# ── Try FlagEmbedding (for BGE) ──────────────────────────────────────────────
try:
    from FlagEmbedding import FlagReranker

    HAS_FLAG_RERANKER = True
except ImportError:
    HAS_FLAG_RERANKER = False

# ── Try sentence-transformers CrossEncoder (BGE fallback) ────────────────────
try:
    from sentence_transformers import CrossEncoder

    HAS_CROSS_ENCODER = True
except ImportError:
    HAS_CROSS_ENCODER = False

# ── Try transformers + torch (for Jina v3) ───────────────────────────────────
try:
    import torch
    from transformers import AutoModel

    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False
    logger.warning("transformers/torch not available — Jina v3 reranker disabled.")
# ─────────────────────────────────────────────────────────────────────────────
# JINA V3 RERANKER
# ─────────────────────────────────────────────────────────────────────────────
class JinaV3Reranker:
    """
    Self-hosted Jina-Reranker-v3 wrapper with thread-safe lazy loading.

    The model is loaded on the first ``compute_scores`` call and scored via
    the ``.rerank()`` method supplied by the model's remote code
    (``trust_remote_code=True``). A failed load is remembered in
    ``_load_failed`` and never retried.

    Key differences from the BGE pointwise reranker (as noted for this project):
      - Listwise: all docs for a query are scored in one ``.rerank()`` call
      - 131K context window
      - Built on a Qwen3-0.6B backbone
      - BEIR 61.94 nDCG@10 (vs BGE's 56.51)

    Scores are the ``relevance_score`` values returned by ``.rerank()``.
    NOTE(review): these are described as sigmoid(logits) in 0-1, but that is
    decided by the model's remote code — confirm before thresholding.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        self._model = None  # published only once the model is fully ready
        self._lock = threading.Lock()  # serialises the one-time load
        self._load_failed = False
        self._device = "cpu"  # updated to the device actually used after load

    def _load(self) -> None:
        """Load the model once; on any failure set ``_load_failed`` instead of raising."""
        # Lock-free fast path once loading has finished (or failed).
        if self._model is not None or self._load_failed:
            return
        with self._lock:
            # Re-check under the lock: another thread may have finished first.
            if self._model is not None or self._load_failed:
                return
            if not HAS_TRANSFORMERS:
                logger.error("transformers not installed — cannot load Jina v3")
                self._load_failed = True
                return
            try:
                logger.info(f"Loading Jina v3 reranker: {self.model_name}")
                device = "cuda" if torch.cuda.is_available() else "cpu"
                # Jina v3 uses AutoModel (NOT AutoModelForSequenceClassification);
                # its remote code provides a .rerank() method returning
                # relevance_score directly.
                model = AutoModel.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    dtype="auto",
                )
                # Actually place the weights on the detected device; previously
                # the model stayed on CPU while the log reported "cuda".
                model.to(device)
                model.eval()
                self._device = device
                # Assign last so lock-free readers never see a half-initialised
                # model (e.g. before eval() has been applied).
                self._model = model
                logger.info(
                    f"✓ Jina v3 reranker loaded on {self._device} "
                    f"(model={self.model_name})"
                )
            except Exception as e:
                logger.error(f"Failed to load Jina v3 reranker: {e}", exc_info=True)
                self._load_failed = True

    def compute_scores(
        self,
        query: str,
        docs: List[str],
        max_length: int = 1024,
    ) -> List[float]:
        """
        Score every (query, doc) pair with Jina v3's built-in ``.rerank()``.

        Args:
            query: The search query.
            docs: Document texts to score.
            max_length: Currently unused; kept for interface compatibility
                (``.rerank()`` is called with its own defaults).

        Returns:
            One float per doc, in the ORIGINAL doc order (not sorted).
            ``[]`` for no docs; ``0.5`` for every doc when the model is not
            loaded; ``0.0`` for every doc if ``.rerank()`` raises. A doc whose
            index is absent from the results keeps ``0.0``.
        """
        if not docs:
            return []
        self._load()
        if self._model is None:
            return [0.5] * len(docs)
        try:
            # .rerank() returns dicts like
            #   {"document": str, "relevance_score": float, "index": int}
            # sorted by relevance_score descending, so map each score back to
            # its original position via "index".
            results = self._model.rerank(query, docs)
            scores = [0.0] * len(docs)
            for r in results:
                scores[r["index"]] = float(r["relevance_score"])
            return scores
        except Exception as e:
            logger.error(f"Jina v3 rerank() failed: {e}", exc_info=True)
            return [0.0] * len(docs)

    def is_loaded(self) -> bool:
        """Return True once the model has been successfully loaded."""
        return self._model is not None
# ─────────────────────────────────────────────────────────────────────────────
# UNIFIED RERANKER ADAPTER
# ─────────────────────────────────────────────────────────────────────────────
class BgeRerankerAdapter(RerankerPort):
    """
    Unified reranker adapter — picks BGE or Jina v3 from configuration.

    RERANKER_MODEL=jinaai/jina-reranker-v3  → Jina v3 self-hosted (recommended)
    RERANKER_MODEL=BAAI/bge-reranker-v2-m3  → BGE (legacy)

    When JINA_RERANKER_ENABLED is truthy and JINA_API_KEY holds a real key,
    the Jina cloud API adapter is tried first on every call. Per-call priority:
    Jina API (if available) > Jina v3 self-hosted > BGE. Self-hosted models
    are loaded lazily on first use.

    Note: ``rerank_score`` is written into the input doc dicts in place, and
    the returned list contains those same dict objects.
    """

    # Max content chars sent to the reranker per document.
    # Jina v3: ~1024 tokens ≈ 4096 chars — far more context than BGE's 512 chars.
    MAX_CONTENT_CHARS_JINA = 4096
    MAX_CONTENT_CHARS_BGE = 512

    def __init__(self):
        self.model_name = settings.RERANKER_MODEL
        self._is_jina_v3 = "jina-reranker-v3" in self.model_name.lower()
        self._lock = threading.Lock()  # serialises lazy BGE loading
        self._load_failed = False  # set when BGE loading fails; never retried
        self._bge_model = None
        self._use_flag = False  # True once _bge_model is a FlagReranker

        # Optional cloud reranker; takes priority over self-hosted models.
        self._jina_api = self._init_jina_api()

        # Build the (lazy, so free until used) self-hosted Jina v3 wrapper
        # whenever the configured model is Jina v3 — even when the API is
        # enabled — so it serves as the fallback if the API is unavailable.
        # Previously that case fell through to the BGE loader with a Jina
        # model name.
        self._jina = JinaV3Reranker(self.model_name) if self._is_jina_v3 else None

        if self._jina_api is None:
            if self._is_jina_v3:
                logger.info(f"Reranker configured: Jina v3 self-hosted ({self.model_name})")
            else:
                logger.info(f"Reranker configured: BGE ({self.model_name})")

    @staticmethod
    def _init_jina_api():
        """Build the Jina cloud reranker adapter when enabled with a real key, else None."""
        if not (
            getattr(settings, "JINA_RERANKER_ENABLED", False)
            and getattr(settings, "JINA_API_KEY", "")
        ):
            return None
        try:
            from src.infrastructure.adapters.jina_reranker_adapter import JinaRerankerAPIAdapter

            jina_key = settings.JINA_API_KEY
            # Ignore the placeholder key value.
            if jina_key and jina_key not in ("", "your-jina-api-key-here"):
                api = JinaRerankerAPIAdapter(
                    api_key=jina_key,
                    model=getattr(settings, "JINA_RERANKER_MODEL", "jina-reranker-v3"),
                    timeout=getattr(settings, "JINA_RERANKER_TIMEOUT", 5.0),
                )
                logger.info("Reranker configured: Jina API (cloud, fast)")
                return api
        except Exception as e:
            logger.warning(f"Jina API reranker init failed: {e}")
        return None

    def _load_bge(self):
        """Lazy-load the BGE reranker once (thread-safe); on failure set ``_load_failed``."""
        if self._bge_model is not None or self._load_failed:
            return
        with self._lock:
            # Re-check under the lock: another thread may have finished first.
            if self._bge_model is not None or self._load_failed:
                return
            logger.info(f"Loading BGE reranker: {self.model_name}")
            try:
                if HAS_FLAG_RERANKER and "bge-reranker" in self.model_name.lower():
                    # Copy missing tokenizer methods from PreTrainedTokenizer onto
                    # XLMRobertaTokenizer (compatibility shim for some transformers
                    # versions).
                    try:
                        from transformers import XLMRobertaTokenizer, PreTrainedTokenizer

                        for method_name in [
                            "prepare_for_model",
                            "build_inputs_with_special_tokens",
                            "create_token_type_ids_from_sequences",
                            "get_special_tokens_mask",
                            "convert_tokens_to_string",
                        ]:
                            if not hasattr(XLMRobertaTokenizer, method_name):
                                base_method = getattr(PreTrainedTokenizer, method_name, None)
                                if base_method:
                                    setattr(XLMRobertaTokenizer, method_name, base_method)
                    except Exception as patch_err:
                        logger.debug(f"Tokenizer patch skipped: {patch_err}")
                    self._bge_model = FlagReranker(
                        self.model_name,
                        use_fp16=True,
                        normalize=True,
                        trust_remote_code=True,
                    )
                    self._use_flag = True
                    logger.info("✓ BGE loaded via FlagReranker (fp16, multilingual)")
                elif HAS_CROSS_ENCODER:
                    self._bge_model = CrossEncoder(self.model_name)
                    self._use_flag = False
                    logger.info("✓ BGE loaded via CrossEncoder (fallback)")
                else:
                    logger.error("No BGE backend available (FlagEmbedding or sentence-transformers required)")
                    self._load_failed = True
            except Exception as e:
                logger.error(f"Failed to load BGE reranker '{self.model_name}': {e}", exc_info=True)
                self._load_failed = True

    # ── Helpers ──────────────────────────────────────────────────────────────

    @staticmethod
    def _content_of(doc: Dict[str, Any]) -> str:
        """Return the doc's stripped ``content`` ('' when missing or None)."""
        return (doc.get("content") or "").strip()

    @staticmethod
    def _fallback_by_vector_score(
        docs: List[Dict[str, Any]], top_n: int
    ) -> List[Dict[str, Any]]:
        """Return the top_n docs ordered by retrieval ``score`` (missing/None treated as 0)."""
        return sorted(docs, key=lambda d: d.get("score") or 0, reverse=True)[:top_n]

    # ── Public interface ─────────────────────────────────────────────────────

    def rerank(
        self,
        query: str,
        docs: List[Dict[str, Any]],
        top_n: int = 5,
    ) -> List[Dict[str, Any]]:
        """
        Rerank documents by relevance to query.

        Priority: Jina API (cloud) > Jina v3 self-hosted > BGE.
        Jina v3 path: uses up to MAX_CONTENT_CHARS_JINA (4096) chars per doc.
        BGE path: uses the first MAX_CONTENT_CHARS_BGE (512) chars only.

        Returns:
            Up to top_n docs sorted by ``rerank_score`` descending (or by vector
            ``score`` when a local model is unavailable or fails); ``[]`` for
            no docs.
        """
        if not docs:
            return []
        if self._jina_api is not None:
            try:
                if self._jina_api.is_available():
                    return self._jina_api.rerank(query, docs, top_n)
            except Exception as e:
                # Match the local paths: never let a reranker failure escape.
                logger.warning(f"Jina API rerank failed: {e} — falling back to local reranker")
        if self._jina is not None:
            return self._rerank_jina(query, docs, top_n)
        return self._rerank_bge(query, docs, top_n)

    def _rerank_jina(
        self,
        query: str,
        docs: List[Dict[str, Any]],
        top_n: int,
    ) -> List[Dict[str, Any]]:
        """Rerank using self-hosted Jina v3 (up to 4096 chars of content per doc)."""
        self._jina._load()
        # is_loaded is a method: the original tested the bound method itself,
        # which is always truthy, so the check never fired.
        if self._jina._load_failed or not self._jina.is_loaded():
            logger.warning("Jina v3 unavailable — falling back to vector score ordering")
            return self._fallback_by_vector_score(docs, top_n)

        # Jina reads up to 8x more content per doc than BGE.
        valid_docs = []
        doc_texts = []
        for doc in docs:
            content = self._content_of(doc)
            if content:
                doc_texts.append(content[: self.MAX_CONTENT_CHARS_JINA])
                valid_docs.append(doc)
        if not doc_texts:
            return []

        try:
            scores = self._jina.compute_scores(query, doc_texts)
            for i, doc in enumerate(valid_docs):
                doc["rerank_score"] = scores[i]
            valid_docs.sort(key=lambda d: d["rerank_score"], reverse=True)
            logger.info(
                f"[Reranker] Jina v3: {len(valid_docs)} docs → top {top_n} "
                f"(max_score={valid_docs[0]['rerank_score']:.3f})"
            )
            return valid_docs[:top_n]
        except Exception as e:
            logger.error(f"Jina v3 reranking failed: {e} — falling back to vector score")
            return self._fallback_by_vector_score(docs, top_n)

    def _rerank_bge(
        self,
        query: str,
        docs: List[Dict[str, Any]],
        top_n: int,
    ) -> List[Dict[str, Any]]:
        """Rerank using BGE (first 512 chars of content per doc)."""
        if self._bge_model is None:
            self._load_bge()
        if self._bge_model is None:
            logger.warning("BGE unavailable — falling back to vector score ordering")
            return self._fallback_by_vector_score(docs, top_n)

        pairs = []
        valid_docs = []
        for doc in docs:
            content = self._content_of(doc)
            if content:
                pairs.append([query, content[: self.MAX_CONTENT_CHARS_BGE]])
                valid_docs.append(doc)
        if not pairs:
            return []

        try:
            if self._use_flag:
                scores = self._bge_model.compute_score(pairs, batch_size=64)
            else:
                scores = self._bge_model.predict(pairs)
            # A single pair may be scored as a bare float rather than a sequence.
            if isinstance(scores, float):
                scores = [scores]
            for i, doc in enumerate(valid_docs):
                doc["rerank_score"] = float(scores[i])
            valid_docs.sort(key=lambda d: d["rerank_score"], reverse=True)
            logger.info(
                f"[Reranker] BGE: {len(valid_docs)} docs → top {top_n} "
                f"(max_score={valid_docs[0]['rerank_score']:.3f})"
            )
            return valid_docs[:top_n]
        except Exception as e:
            logger.error(f"BGE reranking failed: {e} — falling back to vector score")
            return self._fallback_by_vector_score(docs, top_n)

    def model_type(self) -> str:
        """Return the configured local model family ("jina_v3" or "bge")."""
        return "jina_v3" if self._is_jina_v3 else "bge"